Add 10 million timestep model, improve README, add difficulty/mode selector, score, make model training have instant graphs and also multiple envs for faster training, better plotting, improve RL model by including multiple players, better reward system, use EnemyFormation instead of individual Enemy sprites

This commit is contained in:
csd4ni3l
2025-11-16 21:27:21 +01:00
parent fb7b45b6df
commit dce64e5d3f
10 changed files with 644 additions and 273 deletions

View File

@@ -2,7 +2,9 @@ Fleet Commander is like Space Invaders but you are the enemy instead of the play
It uses AI (Reinforcement Learning) for the Player, and you, the Enemy has to defeat it. It uses AI (Reinforcement Learning) for the Player, and you, the Enemy has to defeat it.
You can train yourself, or use the default model which comes with the game. I know the game is too easy and too simple, but please understand that doing RL isn't the easiest thing ever. I also did this so late, so yeah.
You can train yourself, or use the default model (10 million timesteps) which comes with the game.
# Install steps: # Install steps:
@@ -18,4 +20,7 @@ You can train yourself, or use the default model which comes with the game.
- `pip3 install -r requirements.txt` - `pip3 install -r requirements.txt`
- `pip3 install torch --index-url https://download.pytorch.org/whl/cpu` - `pip3 install torch --index-url https://download.pytorch.org/whl/cpu`
- `pip3 install stable_baselines3` - `pip3 install stable_baselines3`
- `python3 run.py` - `python3 run.py`
# Disclaimer
AI assistance was used in this project, since I had never done any RL work before. But every instance of AI code was heavily modified by me.

View File

@@ -1,16 +1,17 @@
import arcade, arcade.gui, random, time import arcade, arcade.gui, random, time
from utils.constants import button_style, ENEMY_ROWS, ENEMY_COLS, PLAYER_ATTACK_SPEED from utils.constants import button_style, ENEMY_ATTACK_SPEED, ENEMY_SPEED
from utils.preload import button_texture, button_hovered_texture from utils.preload import button_texture, button_hovered_texture
from stable_baselines3 import PPO from stable_baselines3 import PPO
from game.sprites import Enemy, Player, Bullet from game.sprites import EnemyFormation, Player, Bullet
class Game(arcade.gui.UIView): class Game(arcade.gui.UIView):
def __init__(self, pypresence_client): def __init__(self, pypresence_client, settings):
super().__init__() super().__init__()
self.settings = settings
self.pypresence_client = pypresence_client self.pypresence_client = pypresence_client
self.pypresence_client.update(state="Invading Space") self.pypresence_client.update(state="Invading Space")
@@ -18,42 +19,53 @@ class Game(arcade.gui.UIView):
self.spritelist = arcade.SpriteList() self.spritelist = arcade.SpriteList()
self.player = Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100) # not actually player self.players = []
self.spritelist.append(self.player) for _ in range(settings["player_count"]):
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
self.last_player_shoot = time.perf_counter() # not actually player self.spritelist.append(self.players[-1])
self.model = PPO.load("invader_agent.zip") self.model = PPO.load("invader_agent.zip")
self.enemies: list[Enemy] = [] self.enemy_formation = EnemyFormation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9, self.spritelist, settings["enemy_rows"], settings["enemy_cols"])
self.player_bullets: list[Bullet] = [] self.player_bullets: list[Bullet] = []
self.enemy_bullets: list[Bullet] = [] self.enemy_bullets: list[Bullet] = []
self.summon_enemies() self.player_respawns = settings["player_respawns"]
self.enemy_respawns = settings["enemy_respawns"]
self.score = 0
self.last_enemy_shoot = time.perf_counter()
self.game_over = False
def on_show_view(self): def on_show_view(self):
super().on_show_view() super().on_show_view()
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5) self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
self.back_button.on_click = lambda event: self.main_exit() self.back_button.on_click = lambda event: self.main_exit()
self.score_label = self.anchor.add(arcade.gui.UILabel("Score: 0", font_size=24), anchor_x="center", anchor_y="top")
def main_exit(self): def main_exit(self):
from menus.main import Main from menus.main import Main
self.window.show_view(Main(self.pypresence_client)) self.window.show_view(Main(self.pypresence_client))
def summon_enemies(self):
enemy_start_x = self.window.width * 0.15
enemy_start_y = self.window.height * 0.9
for row in range(ENEMY_ROWS):
for col in range(ENEMY_COLS):
enemy_sprite = Enemy(enemy_start_x + col * 100, enemy_start_y - row * 100)
self.spritelist.append(enemy_sprite)
self.enemies.append(enemy_sprite)
def on_update(self, delta_time): def on_update(self, delta_time):
for enemy in self.enemies: if self.game_over:
enemy.update() return
if self.window.keyboard[arcade.key.LEFT] or self.window.keyboard[arcade.key.A]:
self.enemy_formation.move(self.window.width, self.window.height, "x", -ENEMY_SPEED)
if self.window.keyboard[arcade.key.RIGHT] or self.window.keyboard[arcade.key.D]:
self.enemy_formation.move(self.window.width, self.window.height, "x", ENEMY_SPEED)
if self.window.keyboard[arcade.key.DOWN] or self.window.keyboard[arcade.key.S]:
self.enemy_formation.move(self.window.width, self.window.height, "y", -ENEMY_SPEED)
if self.window.keyboard[arcade.key.UP] or self.window.keyboard[arcade.key.W]:
self.enemy_formation.move(self.window.width, self.window.height, "y", ENEMY_SPEED)
if self.enemy_formation.enemies and self.window.keyboard[arcade.key.SPACE] and time.perf_counter() - self.last_enemy_shoot >= ENEMY_ATTACK_SPEED:
self.last_enemy_shoot = time.perf_counter()
enemy = self.enemy_formation.get_lowest_enemy()
self.shoot(enemy.center_x, enemy.center_y, -1)
bullets_to_remove = [] bullets_to_remove = []
@@ -62,18 +74,21 @@ class Game(arcade.gui.UIView):
bullet_hit = False bullet_hit = False
if bullet.direction_y == 1: if bullet.direction_y == 1:
for enemy in self.enemies: for enemy in self.enemy_formation.enemies:
if bullet.rect.intersection(enemy.rect): if bullet.rect.intersection(enemy.rect):
self.spritelist.remove(enemy) self.enemy_formation.remove_enemy(enemy)
self.enemies.remove(enemy)
self.player.target = None
bullets_to_remove.append(bullet) bullets_to_remove.append(bullet)
bullet_hit = True bullet_hit = True
break break
else: else:
if bullet.rect.intersection(self.player.rect): for player in self.players:
bullets_to_remove.append(bullet) if bullet.rect.intersection(player.rect):
bullet_hit = True self.spritelist.remove(player)
self.players.remove(player)
bullets_to_remove.append(bullet)
bullet_hit = True
self.score += 75
break
if not bullet_hit and bullet.center_y > self.window.height or bullet.center_y < 0: if not bullet_hit and bullet.center_y > self.window.height or bullet.center_y < 0:
bullets_to_remove.append(bullet) bullets_to_remove.append(bullet)
@@ -86,20 +101,40 @@ class Game(arcade.gui.UIView):
elif bullet_to_remove in self.player_bullets: elif bullet_to_remove in self.player_bullets:
self.player_bullets.remove(bullet_to_remove) self.player_bullets.remove(bullet_to_remove)
self.player.update(self.model, self.enemies, self.enemy_bullets, self.window.width, self.window.height) # not actually player for player in self.players:
player.update(self.model, self.enemy_formation, self.enemy_bullets, self.window.width, self.window.height, self.player_respawns / self.settings["player_respawns"], self.enemy_respawns / self.settings["enemy_respawns"]) # not actually player
if self.player.center_x > self.window.width: if player.center_x > self.window.width:
self.player.center_x = self.window.width player.center_x = self.window.width
elif self.player.center_x < 0: elif player.center_x < 0:
self.player.center_x = 0 player.center_x = 0
if self.player.shoot: if player.shoot:
self.player.shoot = False player.shoot = False
self.shoot(self.player.center_x, self.player.center_y, 1) self.shoot(player.center_x, player.center_y, 1)
if time.perf_counter() - self.last_player_shoot >= PLAYER_ATTACK_SPEED: if len(self.players) == 0:
self.last_player_shoot = time.perf_counter() if self.player_respawns > 0:
self.shoot(self.player.center_x, self.player.center_y, 1) self.player_respawns -= 1
for _ in range(self.settings["player_count"]):
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
self.spritelist.append(self.players[-1])
self.score += 300
else:
self.game_over = True
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You (The enemies) won!", font_size=48), anchor_x="center", anchor_y="center")
elif len(self.enemy_formation.enemies) == 0:
if self.enemy_respawns > 0:
self.enemy_respawns -= 1
self.enemy_formation.create_formation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9)
else:
self.game_over = True
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You lost! The Players win!", font_size=48), anchor_x="center", anchor_y="center")
self.score += 5 * delta_time
self.score_label.text = f"Score: {int(self.score)}"
def shoot(self, x, y, direction_y): def shoot(self, x, y, direction_y):
bullet = Bullet(x, y, direction_y) bullet = Bullet(x, y, direction_y)
@@ -112,11 +147,6 @@ class Game(arcade.gui.UIView):
bullets.append(bullet) bullets.append(bullet)
def on_key_press(self, symbol, modifiers):
if symbol == arcade.key.SPACE:
enemy = random.choice(self.enemies)
self.shoot(enemy.center_x, enemy.center_y, -1)
def on_draw(self): def on_draw(self):
super().on_draw() super().on_draw()

View File

@@ -1,10 +1,8 @@
import arcade, time import arcade, time, random, numpy as np
from stable_baselines3 import PPO from stable_baselines3 import PPO
import numpy as np from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED, ENEMY_COLS, ENEMY_ROWS
from utils.preload import player_texture, enemy_texture from utils.preload import player_texture, enemy_texture
class Bullet(arcade.Sprite): class Bullet(arcade.Sprite):
@@ -16,9 +14,100 @@ class Bullet(arcade.Sprite):
def update(self): def update(self):
self.center_y += self.direction_y * BULLET_SPEED self.center_y += self.direction_y * BULLET_SPEED
class Enemy(arcade.Sprite): class EnemyFormation():
def __init__(self, x, y): def __init__(self, start_x, start_y, spritelist: arcade.SpriteList | None, rows, cols):
super().__init__(enemy_texture, center_x=x, center_y=y) self.grid = [[] for _ in range(rows)]
self.start_x = start_x
self.start_y = start_y
self.rows = rows
self.cols = cols
self.spritelist = spritelist
self.create_formation()
def create_formation(self, start_x=None, start_y=None):
if start_x:
self.start_x = start_x
if start_y:
self.start_y = start_y
del self.grid
self.grid = [[] for _ in range(self.rows)]
for row in range(self.rows):
for col in range(self.cols):
enemy_sprite = arcade.Sprite(enemy_texture, center_x=self.start_x + col * 100, center_y=self.start_y - row * 100)
if self.spritelist:
self.spritelist.append(enemy_sprite)
self.grid[row].append(enemy_sprite)
def remove_enemy(self, enemy):
if self.spritelist and enemy not in self.spritelist:
return
for row in range(self.rows):
for col in range(self.cols):
if self.grid[row][col] == enemy:
self.grid[row][col] = None
if self.spritelist:
self.spritelist.remove(enemy)
return
def get_lowest_enemy(self):
valid_cols = []
for col in range(self.cols):
row = self.rows - 1
while row >= 0 and self.grid[row][col] is None:
row -= 1
if row >= 0:
valid_cols.append((col, row))
if not valid_cols:
return None
col, row = random.choice(valid_cols)
return self.grid[row][col]
def move(self, width, height, direction_type, value):
if direction_type == "x":
wall_hit = False
for enemy in self.enemies:
self.start_x += value
enemy.center_x += value
if enemy.center_x + enemy.width / 2 > width or enemy.center_x < enemy.width / 2:
wall_hit = True
if wall_hit:
for enemy in self.enemies:
self.start_x -= value
enemy.center_x -= value
else:
wall_hit = False
for enemy in self.enemies:
self.start_x += value
enemy.center_y += value
if enemy.center_y + enemy.height / 2 > height or enemy.center_y < enemy.height / 2:
wall_hit = True
if wall_hit:
for enemy in self.enemies:
self.start_y -= value
enemy.center_y -= value
@property
def center_x(self):
return self.start_x + (self.cols / 2) * 100
@property
def enemies(self):
return [col for row in self.grid for col in row if not col == None]
class Player(arcade.Sprite): # Not actually the player class Player(arcade.Sprite): # Not actually the player
def __init__(self, x, y): def __init__(self, x, y):
@@ -26,21 +115,19 @@ class Player(arcade.Sprite): # Not actually the player
self.last_target_change = time.perf_counter() self.last_target_change = time.perf_counter()
self.last_shoot = time.perf_counter() self.last_shoot = time.perf_counter()
self.target = None
self.shoot = False self.shoot = False
self.player_speed = 0 self.player_speed = 0
def update(self, model: PPO, enemies, bullets, width, height): def update(self, model: PPO, enemy_formation, bullets, width, height, player_respawns_norm, enemy_respawns_norm):
if enemies: if enemy_formation.enemies:
nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x)) nearest_enemy = min(enemy_formation.enemies, key=lambda e: abs(e.center_x - self.center_x))
enemy_x = (nearest_enemy.center_x - self.center_x) / width enemy_x = (nearest_enemy.center_x - self.center_x) / width
enemy_y = (nearest_enemy.center_y - self.center_y) / height enemy_y = (nearest_enemy.center_y - self.center_y) / height
else: else:
enemy_x = 2 enemy_x = 2
enemy_y = 2 enemy_y = 2
enemy_count = len(enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS)) enemy_count = len(enemy_formation.enemies) / float(max(1, enemy_formation.rows * enemy_formation.cols))
player_x_norm = self.center_x / width player_x_norm = self.center_x / width
curr_bullet = min(bullets, key=lambda b: abs(b.center_x - self.center_x) + abs(b.center_y - self.center_y)) if bullets else None curr_bullet = min(bullets, key=lambda b: abs(b.center_x - self.center_x) + abs(b.center_y - self.center_y)) if bullets else None
@@ -51,18 +138,31 @@ class Player(arcade.Sprite): # Not actually the player
curr_bx = 2.0 curr_bx = 2.0
curr_by = 2.0 curr_by = 2.0
lowest = max(enemies, key=lambda e: e.center_y) if enemies else None lowest = max(enemy_formation.enemies, key=lambda e: e.center_y) if enemy_formation.enemies else None
if lowest is not None: if lowest is not None:
lowest_dy = (lowest.center_y - self.center_y) / float(height) lowest_dy = (lowest.center_y - self.center_y) / float(height)
else: else:
lowest_dy = 2.0 lowest_dy = 2.0
enemy_dispersion = 0.0 enemy_dispersion = 0.0
if enemies: if enemy_formation.enemies:
xs = np.array([e.center_x for e in enemies], dtype=np.float32) xs = np.array([e.center_x for e in enemy_formation.enemies], dtype=np.float32)
enemy_dispersion = float(xs.std()) / float(width) enemy_dispersion = float(xs.std()) / float(width)
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, curr_bx, curr_by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32) obs = np.array([
player_x_norm,
enemy_x,
enemy_y,
lowest_dy,
curr_bx,
curr_by,
self.player_speed,
enemy_count,
enemy_dispersion,
time.perf_counter() - self.last_shoot >= PLAYER_ATTACK_SPEED,
player_respawns_norm,
enemy_respawns_norm
], dtype=np.float32)
action, _ = model.predict(obs, deterministic=True) action, _ = model.predict(obs, deterministic=True)
self.prev_x = self.center_x self.prev_x = self.center_x
@@ -71,6 +171,8 @@ class Player(arcade.Sprite): # Not actually the player
elif action == 1: elif action == 1:
self.center_x += PLAYER_SPEED self.center_x += PLAYER_SPEED
elif action == 2: elif action == 2:
pass
elif action == 3:
t = time.perf_counter() t = time.perf_counter()
if t - self.last_shoot >= PLAYER_ATTACK_SPEED: if t - self.last_shoot >= PLAYER_ATTACK_SPEED:
self.last_shoot = t self.last_shoot = t

BIN
invader_agent.zip Normal file

Binary file not shown.

View File

@@ -62,8 +62,8 @@ class Main(arcade.gui.UIView):
self.settings_button.on_click = lambda event: self.settings() self.settings_button.on_click = lambda event: self.settings()
def play(self): def play(self):
from game.play import Game from menus.mode_selector import ModeSelector
self.window.show_view(Game(self.pypresence_client)) self.window.show_view(ModeSelector(self.pypresence_client))
def settings(self): def settings(self):
from menus.settings import Settings from menus.settings import Settings

72
menus/mode_selector.py Normal file
View File

@@ -0,0 +1,72 @@
import arcade, arcade.gui
from utils.preload import button_texture, button_hovered_texture
from utils.constants import dropdown_style, button_style, DIFFICULTY_LEVELS, DIFFICULTY_SETTINGS
class ModeSelector(arcade.gui.UIView):
    """Pre-game menu for choosing a difficulty preset or custom game settings.

    The view shows a difficulty dropdown and one slider per entry of
    ``DIFFICULTY_SETTINGS``. Sliders are read-only unless the "Custom"
    difficulty is selected. Pressing "Play" starts ``Game`` with the
    currently selected settings.
    """

    def __init__(self, pypresence_client):
        """Build the base layout and start from the "Easy" preset.

        Args:
            pypresence_client: Rich-presence client; its ``update`` method is
                called with the current state and it is handed on to the
                next view.
        """
        super().__init__()

        self.pypresence_client = pypresence_client
        self.pypresence_client.update(state="Selecting Mode")

        self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
        self.box = self.anchor.add(
            arcade.gui.UIBoxLayout(size_hint=(0.75, 0.75), space_between=10),
            anchor_x="center",
            anchor_y="center",
        )

        # Work on a copy: this dict is mutated by set_difficulty_values() and
        # change_value(), and aliasing the shared constant would overwrite the
        # "Easy" preset itself for the rest of the process.
        self.settings = dict(DIFFICULTY_LEVELS["Easy"])

        self.setting_sliders = {}  # setting key -> UISlider
        self.setting_labels = {}   # setting key -> UILabel showing "<name>: <value>"

    def on_show_view(self):
        """Create the back button, difficulty dropdown, sliders and play button."""
        super().on_show_view()

        self.back_button = self.anchor.add(
            arcade.gui.UITextureButton(
                texture=button_texture,
                texture_hovered=button_hovered_texture,
                text='<--',
                style=button_style,
                width=100,
                height=50,
            ),
            anchor_x="left",
            anchor_y="top",
            align_x=5,
            align_y=-5,
        )
        self.back_button.on_click = lambda event: self.main_exit()

        self.box.add(arcade.gui.UILabel("Settings", font_size=32))
        self.box.add(arcade.gui.UISpace(height=self.window.height / 80))

        self.difficulty_selector = self.box.add(
            arcade.gui.UIDropdown(
                default="Easy",
                options=list(DIFFICULTY_LEVELS.keys()),
                active_style=dropdown_style,
                primary_style=dropdown_style,
                dropdown_style=dropdown_style,
                width=self.window.width / 2,
                height=self.window.height / 20,
            )
        )
        self.difficulty_selector.on_change = lambda event: self.set_difficulty_values(event.new_value)

        self.box.add(arcade.gui.UISpace(height=self.window.height / 80))

        for key, data in DIFFICULTY_SETTINGS.items():
            # Each DIFFICULTY_SETTINGS value unpacks to (display name, min, max).
            default, name, min_value, max_value = DIFFICULTY_LEVELS["Easy"][key], *data

            label = self.box.add(arcade.gui.UILabel(text=f"{name}: {default}", font_size=14))
            slider = self.box.add(
                arcade.gui.UISlider(
                    value=default,
                    min_value=min_value,
                    max_value=max_value,
                    step=1,
                    width=self.window.width / 2,
                    height=self.window.height / 25,
                )
            )
            slider._render_steps = lambda surface: None
            slider.on_event = lambda event: None  # disable slider for difficulties
            slider.on_click = lambda event: None  # disable slider for difficulties
            # key=key binds the current key now (avoids the late-binding closure pitfall).
            slider.on_change = lambda e, key=key: self.change_value(key, e.new_value)

            self.setting_sliders[key] = slider
            self.setting_labels[key] = label

        self.play_button = self.box.add(
            arcade.gui.UITextureButton(
                text="Play",
                width=self.window.width / 2,
                height=self.window.height / 15,
                texture=button_texture,
                texture_hovered=button_hovered_texture,
                style=button_style,
            )
        )
        self.play_button.on_click = lambda event: self.start_game()

    def main_exit(self):
        """Return to the main menu (target of the back button)."""
        # Imported lazily, mirroring how the other views in this project switch screens.
        from menus.main import Main

        self.window.show_view(Main(self.pypresence_client))

    def set_difficulty_values(self, difficulty):
        """Apply a difficulty preset and lock or unlock the sliders.

        Args:
            difficulty: Key of ``DIFFICULTY_LEVELS``. Any value other than
                "Custom" makes the sliders read-only.
        """
        for key, value in DIFFICULTY_LEVELS[difficulty].items():
            self.settings[key] = value
            self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {value}"
            self.setting_sliders[key].value = value

        editable = difficulty == "Custom"
        for slider in self.setting_sliders.values():
            if not editable:
                slider.on_event = lambda event: None
                slider.on_click = lambda event: None
            else:
                # Restore the stock UISlider handlers that were replaced by no-ops.
                slider.on_event = lambda event, slider=slider: arcade.gui.UISlider.on_event(slider, event)
                slider.on_click = lambda event, slider=slider: arcade.gui.UISlider.on_click(slider, event)

    def change_value(self, key, value):
        """Store a slider value as an int and refresh its label."""
        self.settings[key] = int(value)
        self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {int(value)}"

    def start_game(self):
        """Switch to the game view with the currently selected settings."""
        from game.play import Game

        self.window.show_view(Game(self.pypresence_client, self.settings))

View File

@@ -1,20 +1,26 @@
import arcade, arcade.gui, threading, io, os, time import arcade, arcade.gui, threading, os, queue, time, shutil
import matplotlib.pyplot as plt
import pandas as pd
from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
from utils.preload import button_texture, button_hovered_texture from utils.preload import button_texture, button_hovered_texture
from utils.rl import SpaceInvadersEnv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from io import BytesIO
from PIL import Image from PIL import Image
from io import BytesIO
from stable_baselines3 import PPO from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.logger import configure from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import DummyVecEnv
from utils.rl import SpaceInvadersEnv def make_env(rank: int, seed: int = 0):
def _init():
env = SpaceInvadersEnv()
env = Monitor(env, filename=os.path.join(monitor_log_dir, f"monitor_{rank}.csv"))
return env
return _init
class TrainModel(arcade.gui.UIView): class TrainModel(arcade.gui.UIView):
def __init__(self, pypresence_client): def __init__(self, pypresence_client):
@@ -24,22 +30,25 @@ class TrainModel(arcade.gui.UIView):
self.pypresence_client.update(state="Model Training") self.pypresence_client.update(state="Model Training")
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1))) self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10)) self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=5))
self.settings = { self.settings = {
setting: data[0] # default value setting: data[0]
for setting, data in MODEL_SETTINGS.items() for setting, data in MODEL_SETTINGS.items()
} }
self.labels = {} self.labels = {}
self.training = False self.training = False
self.training_text = "" self.training_text = "Starting training..."
self.result_queue = queue.Queue()
self.training_thread = None
self.last_progress_update = time.perf_counter() self.last_progress_update = time.perf_counter()
def on_show_view(self): def on_show_view(self):
super().on_show_view() super().on_show_view()
self.show_menu() self.show_menu()
def main_exit(self): def main_exit(self):
@@ -50,128 +59,188 @@ class TrainModel(arcade.gui.UIView):
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5) self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
self.back_button.on_click = lambda event: self.main_exit() self.back_button.on_click = lambda event: self.main_exit()
self.box.add(arcade.gui.UILabel("Settings", font_size=36)) self.box.add(arcade.gui.UILabel("Settings", font_size=32))
for setting, data in MODEL_SETTINGS.items(): for setting, data in MODEL_SETTINGS.items():
default, min_value, max_value, step = data default, min_value, max_value, step = data
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18)) is_int = setting == "n_envs" or (abs(step - 1) < 1e-6 and abs(min_value - round(min_value)) < 1e-6)
val_text = str(int(default)) if is_int else str(default)
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {val_text}", font_size=14))
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25)) slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
slider._render_steps = lambda surface: None slider._render_steps = lambda surface: None
slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value) slider.on_change = lambda e, key=setting, is_int_slider=is_int: self.change_value(key, e.new_value, is_int_slider)
self.labels[setting] = label self.labels[setting] = label
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture)) train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
train_button.on_click = lambda e: self.start_training() train_button.on_click = lambda e: self.start_training()
def change_value(self, key, value): def change_value(self, key, value, is_int=False):
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}" if is_int:
self.settings[key] = self.round_near_int(value) val = int(round(value))
self.settings[key] = val
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
else:
val = self.round_near_int(value)
self.settings[key] = val
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
def start_training(self): def start_training(self):
self.box.clear() self.box.clear()
self.training = True self.training_text = "Starting training..."
self.training_label = self.box.add(arcade.gui.UILabel("Starting training...", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1)))) self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
self.plot_image_widget.visible = False self.plot_image_widget.visible = False
threading.Thread(target=self.train, daemon=True).start() self.training_thread = threading.Thread(target=self.train, daemon=True)
self.training_thread.start()
def on_update(self, delta_time): def on_update(self, delta_time):
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5: try:
self.last_progress_update = time.perf_counter() result = self.result_queue.get_nowait()
try: if result["type"] == "text":
progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv")) self.training_text = result["message"]
except pd.errors.EmptyDataError:
return
progress_text = "" elif result["type"] == "plot":
self.plot_image_widget.texture = result["image"]
self.plot_image_widget.width = result["image"].width
self.plot_image_widget.height = result["image"].height
self.plot_image_widget.trigger_render()
self.plot_image_widget.visible = True
for key, value in progress_df.items(): elif result["type"] == "finished":
progress_text += f"{key}: {round(value.iloc[-1], 6)}\n" self.training = False
self.training_text = "Training finished."
self.training_text = progress_text except queue.Empty:
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and all([os.path.exists(os.path.join(monitor_log_dir, f"monitor_{i}.csv.monitor.csv")) for i in range(int(self.settings["n_envs"]))]) and time.perf_counter() - self.last_progress_update >= 1:
self.last_progress_update = time.perf_counter()
self.plot_results()
if hasattr(self, "training_label"): if hasattr(self, "training_label"):
self.training_label.text = self.training_text self.training_label.text = self.training_text
def round_near_int(self, x, tol=1e-4): def round_near_int(self, x, tol=1e-4):
nearest = round(x) nearest = round(x)
if abs(x - nearest) < tol: if abs(x - nearest) < tol:
return nearest return nearest
return x return x
def train(self): def train(self):
os.makedirs(monitor_log_dir, exist_ok=True) if os.path.exists(monitor_log_dir):
env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv")) shutil.rmtree(monitor_log_dir)
os.makedirs(monitor_log_dir)
n_envs = int(self.settings["n_envs"])
env = DummyVecEnv([make_env(i) for i in range(n_envs)])
n_steps = int(self.settings["n_steps"])
batch_size = int(self.settings["batch_size"])
total_steps_per_rollout = n_steps * max(1, n_envs)
if total_steps_per_rollout % batch_size != 0:
batch_size = max(64, total_steps_per_rollout // max(1, total_steps_per_rollout // batch_size))
print(f"Warning: Adjusting batch size to {batch_size} for {n_envs} envs.")
model = PPO( model = PPO(
"MlpPolicy", "MlpPolicy",
env, env,
n_steps=self.settings["n_steps"], n_steps=n_steps,
batch_size=self.settings["batch_size"], batch_size=batch_size,
n_epochs=self.settings["n_epochs"], n_epochs=int(self.settings["n_epochs"]),
learning_rate=self.settings["learning_rate"], learning_rate=float(self.settings["learning_rate"]),
verbose=1, verbose=1,
device="cpu", device="cpu",
gamma=self.settings["gamma"], gamma=float(self.settings["gamma"]),
ent_coef=self.settings["ent_coef"], ent_coef=float(self.settings["ent_coef"]),
clip_range=self.settings["clip_range"], clip_range=float(self.settings["clip_range"]),
) )
new_logger = configure( new_logger = configure(folder=monitor_log_dir, format_strings=["csv"])
folder=monitor_log_dir, format_strings=["csv"]
)
model.set_logger(new_logger) model.set_logger(new_logger)
model.learn(self.settings["learning_steps"]) try:
model.save("invader_agent") self.training = True
model.learn(int(self.settings["learning_steps"]))
model.save("invader_agent")
except Exception as e:
print(f"Error during training: {e}")
self.result_queue.put({"type": "text", "message": f"Error:\n{e}"})
finally:
try:
env.close()
except Exception:
pass
self.result_queue.put({"type": "finished"})
self.training = False def plot_results(self):
try:
reward_df = pd.read_csv(os.path.join(monitor_log_dir, "progress.csv"))
except pd.errors.EmptyDataError:
return
self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv")) all_monitor_files = [os.path.join(monitor_log_dir, f) for f in os.listdir(monitor_log_dir) if f.startswith("monitor_") and f.endswith(".csv")]
try:
df_list = [pd.read_csv(f, skiprows=1) for f in all_monitor_files]
except pd.errors.EmptyDataError:
return
def plot_results(self, log_path, loss_log_path): monitor_df = pd.concat(df_list).sort_values(by='t')
df = pd.read_csv(log_path, skiprows=1) monitor_df['total_timesteps'] = monitor_df['l'].cumsum()
loss_log_path = os.path.join(monitor_log_dir, "progress.csv")
loss_df = None
if os.path.exists(loss_log_path):
try:
loss_df = pd.read_csv(loss_log_path)
except Exception:
loss_df = None
fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100) fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)
loss_df = pd.read_csv(loss_log_path) if monitor_df is not None and 'total_timesteps' in monitor_df.columns and 'r' in monitor_df.columns:
axes[0].plot(monitor_df['total_timesteps'], monitor_df['r'].rolling(window=10).mean(), label='Episodic Reward (Rolling 10)')
axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward') elif reward_df is not None and 'time/total_timesteps' in reward_df.columns and 'rollout/ep_rew_mean' in reward_df.columns:
axes[0].plot(reward_df['time/total_timesteps'], reward_df['rollout/ep_rew_mean'], label='Ep reward mean')
else:
axes[0].text(0.5, 0.5, "No reward data available", horizontalalignment='center', verticalalignment='center')
axes[0].set_title('PPO Training: Episodic Reward') axes[0].set_title('PPO Training: Episodic Reward')
axes[0].set_xlabel('Total Timesteps') axes[0].set_xlabel('Total Timesteps')
axes[0].set_ylabel('Reward') axes[0].set_ylabel('Reward')
axes[0].grid(True) axes[0].grid(True)
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss') axes[1].set_title('PPO Training: Loss & Variance')
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
axes[1].set_title('PPO Training: Loss Functions')
axes[1].set_xlabel('Total Timesteps') axes[1].set_xlabel('Total Timesteps')
axes[1].set_ylabel('Loss Value') axes[1].set_ylabel('Value')
axes[1].legend()
axes[1].grid(True) axes[1].grid(True)
if loss_df is not None and 'time/total_timesteps' in loss_df.columns and 'train/policy_gradient_loss' in loss_df.columns and 'train/value_loss' in loss_df.columns and 'train/explained_variance' in loss_df.columns:
tcol = 'time/total_timesteps'
axes[1].plot(loss_df[tcol], loss_df['train/policy_gradient_loss'], label='Policy Loss')
axes[1].plot(loss_df[tcol], loss_df['train/value_loss'], label='Value Loss')
axes[1].plot(loss_df[tcol], loss_df['train/explained_variance'], label='Explained Variance')
axes[1].legend()
else:
axes[1].text(0.5, 0.5, "No loss/variance data available", horizontalalignment='center', verticalalignment='center')
plt.tight_layout() plt.tight_layout()
buffer = BytesIO() buffer = BytesIO()
plt.savefig(buffer, format='png') plt.savefig(buffer, format='png', bbox_inches='tight')
buffer.seek(0) buffer.seek(0)
plt.close(fig) plt.close(fig)
plot_texture = arcade.Texture(Image.open(buffer)) pil_img = Image.open(buffer).convert("RGBA")
self.plot_image_widget.texture = plot_texture plot_texture = arcade.Texture(pil_img)
self.plot_image_widget.size_hint = (None, None)
self.plot_image_widget.width = plot_texture.width
self.plot_image_widget.height = plot_texture.height
self.plot_image_widget.visible = True self.result_queue.put({"type": "plot", "image": plot_texture})
self.training_text = "Training finished. Plot displayed."

2
run.py
View File

@@ -29,7 +29,7 @@ timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = f"debug_{timestamp}.log" log_filename = f"debug_{timestamp}.log"
logging.basicConfig(filename=f'{os.path.join(log_dir, log_filename)}', format='%(asctime)s %(name)s %(levelname)s: %(message)s', level=logging.DEBUG) logging.basicConfig(filename=f'{os.path.join(log_dir, log_filename)}', format='%(asctime)s %(name)s %(levelname)s: %(message)s', level=logging.DEBUG)
# Silence chatty third-party loggers so they do not flood the DEBUG-level
# log file configured above. matplotlib's font manager logs under
# "matplotlib.font_manager" (the module's own name), so that exact
# spelling is required for the entry to match a real logger.
for logger_name_to_disable in ['arcade', "matplotlib", "matplotlib.font_manager", "PIL"]:
    noisy_logger = logging.getLogger(logger_name_to_disable)
    noisy_logger.propagate = False
    noisy_logger.disabled = True

View File

@@ -3,25 +3,68 @@ from arcade.types import Color
from arcade.gui.widgets.buttons import UITextureButtonStyle, UIFlatButtonStyle from arcade.gui.widgets.buttons import UITextureButtonStyle, UIFlatButtonStyle
from arcade.gui.widgets.slider import UISliderStyle from arcade.gui.widgets.slider import UISliderStyle
# Movement and attack tuning values imported by the game view and the
# RL environment. The environment applies the *_SPEED values as a
# per-step position delta.
ENEMY_SPEED = 5
ENEMY_ATTACK_SPEED = 0.75

PLAYER_SPEED = 5  # not actually player
PLAYER_ATTACK_SPEED = 0.75

BULLET_SPEED = 5
BULLET_RADIUS = 15
# PPO hyperparameters exposed in the training UI.
# Each entry is [default, min, max, step].
MODEL_SETTINGS = {
    "n_steps": [1024, 256, 8192, 256],
    "batch_size": [128, 16, 512, 16],
    "n_epochs": [10, 1, 50, 1],
    "learning_rate": [3e-4, 1e-5, 1e-2, 1e-5],
    "gamma": [0.99, 0.8, 0.9999, 0.001],
    "ent_coef": [0.015, 0.0, 0.1, 0.001],
    "clip_range": [0.2, 0.1, 0.4, 0.01],
    "learning_steps": [1_000_000, 50_000, 25_000_000, 50_000],
    # Kept as a list (not a tuple) so every entry has the same type.
    "n_envs": [12, 1, 128, 1],
}
# Descriptors for the individually adjustable difficulty values, keyed by the
# same names used inside each DIFFICULTY_LEVELS preset.
# Each entry is [display label, int, int]; the two integers are presumably the
# allowed minimum and maximum for that value — TODO confirm against the
# settings UI that consumes this table.
DIFFICULTY_SETTINGS = {
"enemy_rows": ["Enemy Rows", 1, 6],
"enemy_cols": ["Enemy Columns", 1, 7],
"enemy_respawns": ["Enemy Respawns", 1, 5],
"player_count": ["Player Count", 1, 10],
"player_respawns": ["Player Respawns", 1, 5]
}
# Named difficulty presets. Every populated preset supplies the same five
# keys that DIFFICULTY_SETTINGS describes.
DIFFICULTY_LEVELS = {
    "Easy": {
        "enemy_rows": 3,
        "enemy_cols": 4,
        "enemy_respawns": 5,
        "player_count": 2,
        "player_respawns": 2,
    },
    "Medium": {
        "enemy_rows": 3,
        "enemy_cols": 5,
        "enemy_respawns": 4,
        "player_count": 4,
        "player_respawns": 3,
    },
    "Hard": {
        "enemy_rows": 4,
        "enemy_cols": 6,
        "enemy_respawns": 3,
        "player_count": 6,
        "player_respawns": 4,
    },
    "Extra Hard": {
        "enemy_rows": 6,
        "enemy_cols": 7,
        "enemy_respawns": 2,
        "player_count": 8,
        "player_respawns": 5,
    },
    # NOTE(review): left empty here; presumably filled in at runtime from
    # user-chosen values. SpaceInvadersEnv indexes "player_respawns" and
    # "enemy_respawns" directly, so an unfilled "Custom" preset would raise
    # KeyError there — confirm callers populate it first.
    "Custom": {},
}
menu_background_color = (30, 30, 47) menu_background_color = (30, 30, 47)

View File

@@ -1,78 +1,87 @@
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import arcade import arcade
import time
import random import random
from game.sprites import Enemy, Player, Bullet from game.sprites import EnemyFormation, Player, Bullet
from utils.constants import PLAYER_SPEED, BULLET_SPEED, PLAYER_ATTACK_SPEED, ENEMY_ROWS, ENEMY_COLS from utils.constants import PLAYER_SPEED, BULLET_SPEED, ENEMY_SPEED, DIFFICULTY_LEVELS
class SpaceInvadersEnv(gym.Env): class SpaceInvadersEnv(gym.Env):
def __init__(self, width=800, height=600, difficulty="Hard"):
    """Set up spaces and per-episode bookkeeping for the given difficulty.

    Args:
        width: Playfield width used for clamping and normalisation.
        height: Playfield height used for normalisation.
        difficulty: Key into ``DIFFICULTY_LEVELS`` selecting formation
            size and respawn counts.

    Raises:
        ValueError: If ``difficulty`` is not a key of ``DIFFICULTY_LEVELS``.
    """
    super().__init__()

    self.width = width
    self.height = height

    # Four discrete actions: 0 = move left, 1 = move right, 2 = idle, 3 = shoot.
    self.action_space = gym.spaces.Discrete(4)
    # Twelve features, matching the vector assembled in _obs().
    self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)

    if difficulty not in DIFFICULTY_LEVELS:
        raise ValueError(f"Unknown difficulty: {difficulty}. Available: {list(DIFFICULTY_LEVELS.keys())}")
    self.difficulty_settings = DIFFICULTY_LEVELS[difficulty]

    # Entities; the player and formation are created in reset().
    self.bullets = []
    self.player = None
    self.enemy_formation = None

    self.player_speed = 0.0
    self.max_steps = 2000  # episode is truncated once this many steps elapse
    self.current_step = 0
    self.enemies_killed = 0
    self.enemy_move_speed = ENEMY_SPEED

    # Respawn budgets come from the preset; the *_remaining counters are
    # refilled from them on every reset().
    self.player_respawns = self.difficulty_settings["player_respawns"]
    self.enemy_respawns = self.difficulty_settings["enemy_respawns"]
    self.player_respawns_remaining = 0
    self.enemy_respawns_remaining = 0
    self.player_alive = True

    # Shooting is rate-limited in environment steps rather than wall-clock time.
    self.player_attack_cooldown_steps = 5
    self.current_cooldown = 0
def reset(self, seed=None, options=None):
    """Start a new episode and return ``(observation, info)``.

    Args:
        seed: Optional seed. When given it seeds Gymnasium's ``np_random``
            as well as the global ``numpy`` and ``random`` generators,
            which this environment uses for spawning and enemy behaviour.
        options: Accepted for Gymnasium API compatibility; unused.

    Returns:
        The initial observation from ``_obs()`` and an empty info dict.
    """
    # Gymnasium expects subclasses to forward the seed so that
    # self.np_random is initialised consistently.
    super().reset(seed=seed)
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    self.bullets = []

    # Spawn the player at y=100, horizontally within a third of the width
    # either side of the centre.
    self.player = Player(self.width / 2 + random.randint(int(-self.width / 3), int(self.width / 3)), 100)
    self.player_speed = 0.0
    self.player_alive = True

    self.current_step = 0
    self.enemies_killed = 0
    self.current_cooldown = 0

    # Refill the respawn budgets chosen by the difficulty preset.
    self.player_respawns_remaining = self.player_respawns
    self.enemy_respawns_remaining = self.enemy_respawns

    start_x = self.width * 0.15
    start_y = self.height * 0.9
    self.enemy_formation = EnemyFormation(
        start_x,
        start_y,
        None,
        self.difficulty_settings["enemy_rows"],
        self.difficulty_settings["enemy_cols"],
    )

    return self._obs(), {}
def _nearest_enemy(self): def _nearest_enemy(self):
if not self.enemies: if not self.enemy_formation.enemies:
return None return None
return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
return min(self.enemy_formation.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
def _lowest_enemy(self): def _lowest_enemy(self):
if not self.enemies: if not self.enemy_formation.enemies:
return None return None
return max(self.enemies, key=lambda e: e.center_y)
return min(self.enemy_formation.enemies, key=lambda e: e.center_y)
def _nearest_enemy_bullet(self): def _nearest_enemy_bullet(self):
enemy_bullets = [b for b in self.bullets if b.direction_y == -1] enemy_bullets = [b for b in self.bullets if b.direction_y == -1]
if not enemy_bullets: if not enemy_bullets:
return None return None
return min(enemy_bullets, key=lambda b: abs(b.center_x - self.player.center_x) + abs(b.center_y - self.player.center_y)) return min(enemy_bullets, key=lambda b: abs(b.center_x - self.player.center_x) + abs(b.center_y - self.player.center_y))
def _obs(self): def _obs(self):
if self.enemies: if self.enemy_formation.enemies and self.player_alive:
nearest = self._nearest_enemy() nearest = self._nearest_enemy()
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width) enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
enemy_y = (nearest.center_y - self.player.center_y) / float(self.height) enemy_y = (nearest.center_y - self.player.center_y) / float(self.height)
@@ -81,31 +90,60 @@ class SpaceInvadersEnv(gym.Env):
enemy_y = 2.0 enemy_y = 2.0
lowest = self._lowest_enemy() lowest = self._lowest_enemy()
if lowest is not None and self.player_alive:
if lowest is not None:
lowest_dy = (lowest.center_y - self.player.center_y) / float(self.height) lowest_dy = (lowest.center_y - self.player.center_y) / float(self.height)
else: else:
lowest_dy = 2.0 lowest_dy = 2.0
nb = self._nearest_enemy_bullet() nb = self._nearest_enemy_bullet()
if nb is not None: if nb is not None and self.player_alive:
bx = (nb.center_x - self.player.center_x) / float(self.width) bx = (nb.center_x - self.player.center_x) / float(self.width)
by = (nb.center_y - self.player.center_y) / float(self.height) by = (nb.center_y - self.player.center_y) / float(self.height)
else: else:
bx = 2.0 bx = 2.0
by = 2.0 by = 2.0
enemy_count = len(self.enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS)) enemy_count = len(self.enemy_formation.enemies) / float(max(1, self.difficulty_settings["enemy_rows"] * self.difficulty_settings["enemy_cols"]))
player_x_norm = self.player.center_x / float(self.width) player_x_norm = self.player.center_x / float(self.width) if self.player_alive else 0.5
enemy_dispersion = 0.0
if self.enemies: enemy_dispersion = 0.0
xs = np.array([e.center_x for e in self.enemies], dtype=np.float32) if self.enemy_formation.enemies:
xs = np.array([e.center_x for e in self.enemy_formation.enemies], dtype=np.float32)
enemy_dispersion = float(xs.std()) / float(self.width) enemy_dispersion = float(xs.std()) / float(self.width)
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, bx, by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32) can_shoot = 1.0 if (self.player_alive and self.current_cooldown <= 0) else 0.0
player_respawns_norm = self.player_respawns_remaining / float(max(1, self.player_respawns))
enemy_respawns_norm = self.enemy_respawns_remaining / float(max(1, self.enemy_respawns))
obs = np.array([
player_x_norm,
enemy_x,
enemy_y,
lowest_dy,
bx,
by,
self.player_speed,
enemy_count,
enemy_dispersion,
can_shoot,
player_respawns_norm,
enemy_respawns_norm
], dtype=np.float32)
return obs return obs
def _respawn_player(self):
    """Replace the player with a fresh one and clear incoming fire.

    The new player spawns at y=100, horizontally within a third of the
    width either side of the centre. Only bullets with
    ``direction_y == 1`` are kept, and the shot cooldown is reset.
    """
    spawn_offset = random.randint(int(-self.width / 3), int(self.width / 3))
    self.player = Player(self.width / 2 + spawn_offset, 100)
    self.player_alive = True

    # Rebind to a new list instead of mutating in place: step() may still be
    # iterating the previous list when this runs.
    surviving = []
    for bullet in self.bullets:
        if bullet.direction_y == 1:
            surviving.append(bullet)
    self.bullets = surviving

    self.current_cooldown = 0
def _respawn_enemies(self):
self.enemy_formation.start_x = self.width * 0.15
self.enemy_formation.start_y = self.height * 0.9
self.enemy_formation.create_formation()
def step(self, action): def step(self, action):
reward = 0.0 reward = 0.0
terminated = False terminated = False
@@ -115,113 +153,125 @@ class SpaceInvadersEnv(gym.Env):
if self.current_step >= self.max_steps: if self.current_step >= self.max_steps:
truncated = True truncated = True
nearest = self._nearest_enemy() if self.current_cooldown > 0:
if nearest is not None: self.current_cooldown -= 1
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
else:
enemy_x = 2.0
prev_x = self.player.center_x if self.player_alive:
current_action_dir = 0 prev_x = self.player.center_x
if action == 0: if action == 0:
self.player.center_x -= PLAYER_SPEED self.player.center_x -= PLAYER_SPEED
current_action_dir = -1 elif action == 1:
elif action == 1: self.player.center_x += PLAYER_SPEED
self.player.center_x += PLAYER_SPEED elif action == 2:
current_action_dir = 1 pass
elif action == 2: elif action == 3:
t = time.perf_counter() if self.current_cooldown <= 0:
if t - self.last_shot >= PLAYER_ATTACK_SPEED: self.current_cooldown = self.player_attack_cooldown_steps
self.last_shot = t reward += 0.01
b = Bullet(self.player.center_x, self.player.center_y, 1)
b = Bullet(self.player.center_x, self.player.center_y, 1) self.bullets.append(b)
else:
reward -= 0.05
self.bullets.append(b) if self.enemy_formation.enemies:
nearest = self._nearest_enemy()
if enemy_x != 2.0 and abs(enemy_x) < 0.04: alignment = abs(nearest.center_x - self.player.center_x) / self.width
if alignment < 0.025:
reward += 0.3 reward += 0.3
elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
reward += 0.1
if self.player.center_x > self.width: self.player.center_x = np.clip(self.player.center_x, 0, self.width)
self.player.center_x = self.width self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
elif self.player.center_x < 0:
self.player.center_x = 0
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED) if self.enemy_formation.enemies and self.player_alive:
if self.enemy_formation.center_x < self.player.center_x:
self.enemy_formation.move(self.width, self.height, "x", self.enemy_move_speed)
elif self.enemy_formation.center_x > self.player.center_x:
self.enemy_formation.move(self.width, self.height, "x", -self.enemy_move_speed)
if random.random() < 0.02:
if random.random() < 0.5:
self.enemy_formation.move(self.width, self.height, "y", -self.enemy_move_speed)
else:
self.enemy_formation.move(self.width, self.height, "y", self.enemy_move_speed)
if current_action_dir != 0: bullets_to_remove = []
if self.last_direction != 0 and current_action_dir != self.last_direction:
if self.steps_since_direction_change < 3:
reward -= 0.1
self.steps_since_direction_change = 0 for b in self.bullets:
else:
self.steps_since_direction_change += 1
self.last_direction = current_action_dir
if enemy_x != 2.0:
if abs(enemy_x) < 0.03:
reward += 0.1
elif abs(enemy_x) < 0.08:
reward += 0.05
for b in list(self.bullets):
b.center_y += b.direction_y * BULLET_SPEED b.center_y += b.direction_y * BULLET_SPEED
if b.center_y > self.height or b.center_y < 0: if b.center_y > self.height or b.center_y < 0:
try: bullets_to_remove.append(b)
self.bullets.remove(b) continue
except ValueError:
pass
for b in list(self.bullets):
if b.direction_y == 1: if b.direction_y == 1:
for e in list(self.enemies): for e in self.enemy_formation.enemies:
if arcade.check_for_collision(b, e): if arcade.check_for_collision(b, e):
try: self.enemy_formation.remove_enemy(e)
self.enemies.remove(e) bullets_to_remove.append(b)
except ValueError: reward += 10.0
pass self.enemies_killed += 1
try:
self.bullets.remove(b)
except ValueError:
pass
reward += 1.0
break break
for b in list(self.bullets): elif b.direction_y == -1 and self.player_alive:
if b.direction_y == -1:
if arcade.check_for_collision(b, self.player): if arcade.check_for_collision(b, self.player):
try: bullets_to_remove.append(b)
self.bullets.remove(b) reward -= 10.0
except ValueError: self.player_alive = False
pass
reward -= 5.0 if self.player_respawns_remaining > 0:
self.player_respawns_remaining -= 1
self._respawn_player()
reward += 2.0
else:
terminated = True
for b in bullets_to_remove:
if b in self.bullets:
self.bullets.remove(b)
if self.player_alive:
lowest_enemy = self._lowest_enemy()
if lowest_enemy and lowest_enemy.center_y <= self.player.center_y:
reward -= 10.0
self.player_alive = False
if self.player_respawns_remaining > 0:
self.player_respawns_remaining -= 1
self._respawn_player()
else:
terminated = True terminated = True
if not self.enemies: if not self.enemy_formation.enemies:
reward += 10.0 reward += 50.0
terminated = True
if self.enemy_respawns_remaining > 0:
self.enemy_respawns_remaining -= 1
self._respawn_enemies()
reward += 20.0
else:
reward += 100.0
terminated = True
if self.enemies and random.random() < 0.05: shooting_prob = 0.05 + (0.05 * (1.0 - len(self.enemy_formation.enemies) / (self.difficulty_settings["enemy_rows"] * self.difficulty_settings["enemy_cols"])))
e = random.choice(self.enemies) if self.enemy_formation.enemies and random.random() < shooting_prob:
b = Bullet(e.center_x, e.center_y, -1) enemy = self.enemy_formation.get_lowest_enemy()
self.bullets.append(b) if enemy:
b = Bullet(enemy.center_x, enemy.center_y, -1)
self.bullets.append(b)
curr_bullet = self._nearest_enemy_bullet() if self.player_alive:
if curr_bullet is not None: edge_threshold = self.width * 0.15
curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width) if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold:
else: reward -= 0.03
curr_bx = 2.0
if self.prev_bx != 2.0 and curr_bx != 2.0: reward -= 0.005
if abs(curr_bx) > abs(self.prev_bx):
reward += 0.02
reward -= 0.01
obs = self._obs() obs = self._obs()
self.prev_bx = curr_bx
return obs, float(reward), bool(terminated), bool(truncated), {
return obs, float(reward), bool(terminated), bool(truncated), {} "enemies_killed": self.enemies_killed,
"step": self.current_step,
"player_respawns_remaining": self.player_respawns_remaining,
"enemy_respawns_remaining": self.enemy_respawns_remaining
}