mirror of
https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Add 10 million timestep model, improve README, add diffculty/mode selector, score, make model training have instant graphs and also multiple envs for faster training, better plotting, improve RL model by including multiple players, better reard system, use EnemyFormation instead of single Enemy-es
This commit is contained in:
@@ -2,7 +2,9 @@ Fleet Commander is like Space Invaders but you are the enemy instead of the play
|
|||||||
|
|
||||||
It uses AI (Reinforcement Learning) for the Player, and you, the Enemy has to defeat it.
|
It uses AI (Reinforcement Learning) for the Player, and you, the Enemy has to defeat it.
|
||||||
|
|
||||||
You can train yourself, or use the default model which comes with the game.
|
I know the game is too easy and is too simple, but please understand that doing RL isnt the easiest thing ever. I also did this so late so yeah.
|
||||||
|
|
||||||
|
You can train yourself, or use the default model(10 million timesteps) which comes with the game.
|
||||||
|
|
||||||
# Install steps:
|
# Install steps:
|
||||||
|
|
||||||
@@ -19,3 +21,6 @@ You can train yourself, or use the default model which comes with the game.
|
|||||||
- `pip3 install torch --index-url https://download.pytorch.org/whl/cpu`
|
- `pip3 install torch --index-url https://download.pytorch.org/whl/cpu`
|
||||||
- `pip3 install stable_baselines3`
|
- `pip3 install stable_baselines3`
|
||||||
- `python3 run.py`
|
- `python3 run.py`
|
||||||
|
|
||||||
|
# Disclaimer
|
||||||
|
AI assistance was used in this project, since i never did any RL work before. But every instance of AI code was heavily modified by me.
|
||||||
114
game/play.py
114
game/play.py
@@ -1,16 +1,17 @@
|
|||||||
import arcade, arcade.gui, random, time
|
import arcade, arcade.gui, random, time
|
||||||
|
|
||||||
from utils.constants import button_style, ENEMY_ROWS, ENEMY_COLS, PLAYER_ATTACK_SPEED
|
from utils.constants import button_style, ENEMY_ATTACK_SPEED, ENEMY_SPEED
|
||||||
from utils.preload import button_texture, button_hovered_texture
|
from utils.preload import button_texture, button_hovered_texture
|
||||||
|
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
|
|
||||||
from game.sprites import Enemy, Player, Bullet
|
from game.sprites import EnemyFormation, Player, Bullet
|
||||||
|
|
||||||
class Game(arcade.gui.UIView):
|
class Game(arcade.gui.UIView):
|
||||||
def __init__(self, pypresence_client):
|
def __init__(self, pypresence_client, settings):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.settings = settings
|
||||||
self.pypresence_client = pypresence_client
|
self.pypresence_client = pypresence_client
|
||||||
self.pypresence_client.update(state="Invading Space")
|
self.pypresence_client.update(state="Invading Space")
|
||||||
|
|
||||||
@@ -18,42 +19,53 @@ class Game(arcade.gui.UIView):
|
|||||||
|
|
||||||
self.spritelist = arcade.SpriteList()
|
self.spritelist = arcade.SpriteList()
|
||||||
|
|
||||||
self.player = Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100) # not actually player
|
self.players = []
|
||||||
self.spritelist.append(self.player)
|
for _ in range(settings["player_count"]):
|
||||||
|
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
|
||||||
self.last_player_shoot = time.perf_counter() # not actually player
|
self.spritelist.append(self.players[-1])
|
||||||
|
|
||||||
self.model = PPO.load("invader_agent.zip")
|
self.model = PPO.load("invader_agent.zip")
|
||||||
|
|
||||||
self.enemies: list[Enemy] = []
|
self.enemy_formation = EnemyFormation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9, self.spritelist, settings["enemy_rows"], settings["enemy_cols"])
|
||||||
self.player_bullets: list[Bullet] = []
|
self.player_bullets: list[Bullet] = []
|
||||||
self.enemy_bullets: list[Bullet] = []
|
self.enemy_bullets: list[Bullet] = []
|
||||||
|
|
||||||
self.summon_enemies()
|
self.player_respawns = settings["player_respawns"]
|
||||||
|
self.enemy_respawns = settings["enemy_respawns"]
|
||||||
|
|
||||||
|
self.score = 0
|
||||||
|
|
||||||
|
self.last_enemy_shoot = time.perf_counter()
|
||||||
|
|
||||||
|
self.game_over = False
|
||||||
|
|
||||||
def on_show_view(self):
|
def on_show_view(self):
|
||||||
super().on_show_view()
|
super().on_show_view()
|
||||||
|
|
||||||
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
||||||
self.back_button.on_click = lambda event: self.main_exit()
|
self.back_button.on_click = lambda event: self.main_exit()
|
||||||
|
self.score_label = self.anchor.add(arcade.gui.UILabel("Score: 0", font_size=24), anchor_x="center", anchor_y="top")
|
||||||
|
|
||||||
def main_exit(self):
|
def main_exit(self):
|
||||||
from menus.main import Main
|
from menus.main import Main
|
||||||
self.window.show_view(Main(self.pypresence_client))
|
self.window.show_view(Main(self.pypresence_client))
|
||||||
|
|
||||||
def summon_enemies(self):
|
|
||||||
enemy_start_x = self.window.width * 0.15
|
|
||||||
enemy_start_y = self.window.height * 0.9
|
|
||||||
|
|
||||||
for row in range(ENEMY_ROWS):
|
|
||||||
for col in range(ENEMY_COLS):
|
|
||||||
enemy_sprite = Enemy(enemy_start_x + col * 100, enemy_start_y - row * 100)
|
|
||||||
self.spritelist.append(enemy_sprite)
|
|
||||||
self.enemies.append(enemy_sprite)
|
|
||||||
|
|
||||||
def on_update(self, delta_time):
|
def on_update(self, delta_time):
|
||||||
for enemy in self.enemies:
|
if self.game_over:
|
||||||
enemy.update()
|
return
|
||||||
|
|
||||||
|
if self.window.keyboard[arcade.key.LEFT] or self.window.keyboard[arcade.key.A]:
|
||||||
|
self.enemy_formation.move(self.window.width, self.window.height, "x", -ENEMY_SPEED)
|
||||||
|
if self.window.keyboard[arcade.key.RIGHT] or self.window.keyboard[arcade.key.D]:
|
||||||
|
self.enemy_formation.move(self.window.width, self.window.height, "x", ENEMY_SPEED)
|
||||||
|
if self.window.keyboard[arcade.key.DOWN] or self.window.keyboard[arcade.key.S]:
|
||||||
|
self.enemy_formation.move(self.window.width, self.window.height, "y", -ENEMY_SPEED)
|
||||||
|
if self.window.keyboard[arcade.key.UP] or self.window.keyboard[arcade.key.W]:
|
||||||
|
self.enemy_formation.move(self.window.width, self.window.height, "y", ENEMY_SPEED)
|
||||||
|
if self.enemy_formation.enemies and self.window.keyboard[arcade.key.SPACE] and time.perf_counter() - self.last_enemy_shoot >= ENEMY_ATTACK_SPEED:
|
||||||
|
self.last_enemy_shoot = time.perf_counter()
|
||||||
|
enemy = self.enemy_formation.get_lowest_enemy()
|
||||||
|
self.shoot(enemy.center_x, enemy.center_y, -1)
|
||||||
|
|
||||||
bullets_to_remove = []
|
bullets_to_remove = []
|
||||||
|
|
||||||
@@ -62,18 +74,21 @@ class Game(arcade.gui.UIView):
|
|||||||
|
|
||||||
bullet_hit = False
|
bullet_hit = False
|
||||||
if bullet.direction_y == 1:
|
if bullet.direction_y == 1:
|
||||||
for enemy in self.enemies:
|
for enemy in self.enemy_formation.enemies:
|
||||||
if bullet.rect.intersection(enemy.rect):
|
if bullet.rect.intersection(enemy.rect):
|
||||||
self.spritelist.remove(enemy)
|
self.enemy_formation.remove_enemy(enemy)
|
||||||
self.enemies.remove(enemy)
|
|
||||||
self.player.target = None
|
|
||||||
bullets_to_remove.append(bullet)
|
bullets_to_remove.append(bullet)
|
||||||
bullet_hit = True
|
bullet_hit = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if bullet.rect.intersection(self.player.rect):
|
for player in self.players:
|
||||||
|
if bullet.rect.intersection(player.rect):
|
||||||
|
self.spritelist.remove(player)
|
||||||
|
self.players.remove(player)
|
||||||
bullets_to_remove.append(bullet)
|
bullets_to_remove.append(bullet)
|
||||||
bullet_hit = True
|
bullet_hit = True
|
||||||
|
self.score += 75
|
||||||
|
break
|
||||||
|
|
||||||
if not bullet_hit and bullet.center_y > self.window.height or bullet.center_y < 0:
|
if not bullet_hit and bullet.center_y > self.window.height or bullet.center_y < 0:
|
||||||
bullets_to_remove.append(bullet)
|
bullets_to_remove.append(bullet)
|
||||||
@@ -86,20 +101,40 @@ class Game(arcade.gui.UIView):
|
|||||||
elif bullet_to_remove in self.player_bullets:
|
elif bullet_to_remove in self.player_bullets:
|
||||||
self.player_bullets.remove(bullet_to_remove)
|
self.player_bullets.remove(bullet_to_remove)
|
||||||
|
|
||||||
self.player.update(self.model, self.enemies, self.enemy_bullets, self.window.width, self.window.height) # not actually player
|
for player in self.players:
|
||||||
|
player.update(self.model, self.enemy_formation, self.enemy_bullets, self.window.width, self.window.height, self.player_respawns / self.settings["player_respawns"], self.enemy_respawns / self.settings["enemy_respawns"]) # not actually player
|
||||||
|
|
||||||
if self.player.center_x > self.window.width:
|
if player.center_x > self.window.width:
|
||||||
self.player.center_x = self.window.width
|
player.center_x = self.window.width
|
||||||
elif self.player.center_x < 0:
|
elif player.center_x < 0:
|
||||||
self.player.center_x = 0
|
player.center_x = 0
|
||||||
|
|
||||||
if self.player.shoot:
|
if player.shoot:
|
||||||
self.player.shoot = False
|
player.shoot = False
|
||||||
self.shoot(self.player.center_x, self.player.center_y, 1)
|
self.shoot(player.center_x, player.center_y, 1)
|
||||||
|
|
||||||
if time.perf_counter() - self.last_player_shoot >= PLAYER_ATTACK_SPEED:
|
if len(self.players) == 0:
|
||||||
self.last_player_shoot = time.perf_counter()
|
if self.player_respawns > 0:
|
||||||
self.shoot(self.player.center_x, self.player.center_y, 1)
|
self.player_respawns -= 1
|
||||||
|
for _ in range(self.settings["player_count"]):
|
||||||
|
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
|
||||||
|
self.spritelist.append(self.players[-1])
|
||||||
|
self.score += 300
|
||||||
|
else:
|
||||||
|
self.game_over = True
|
||||||
|
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You (The enemies) won!", font_size=48), anchor_x="center", anchor_y="center")
|
||||||
|
|
||||||
|
elif len(self.enemy_formation.enemies) == 0:
|
||||||
|
if self.enemy_respawns > 0:
|
||||||
|
self.enemy_respawns -= 1
|
||||||
|
self.enemy_formation.create_formation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9)
|
||||||
|
else:
|
||||||
|
self.game_over = True
|
||||||
|
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You lost! The Players win!", font_size=48), anchor_x="center", anchor_y="center")
|
||||||
|
|
||||||
|
self.score += 5 * delta_time
|
||||||
|
|
||||||
|
self.score_label.text = f"Score: {int(self.score)}"
|
||||||
|
|
||||||
def shoot(self, x, y, direction_y):
|
def shoot(self, x, y, direction_y):
|
||||||
bullet = Bullet(x, y, direction_y)
|
bullet = Bullet(x, y, direction_y)
|
||||||
@@ -112,11 +147,6 @@ class Game(arcade.gui.UIView):
|
|||||||
|
|
||||||
bullets.append(bullet)
|
bullets.append(bullet)
|
||||||
|
|
||||||
def on_key_press(self, symbol, modifiers):
|
|
||||||
if symbol == arcade.key.SPACE:
|
|
||||||
enemy = random.choice(self.enemies)
|
|
||||||
self.shoot(enemy.center_x, enemy.center_y, -1)
|
|
||||||
|
|
||||||
def on_draw(self):
|
def on_draw(self):
|
||||||
super().on_draw()
|
super().on_draw()
|
||||||
|
|
||||||
|
|||||||
136
game/sprites.py
136
game/sprites.py
@@ -1,10 +1,8 @@
|
|||||||
import arcade, time
|
import arcade, time, random, numpy as np
|
||||||
|
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
|
|
||||||
import numpy as np
|
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED
|
||||||
|
|
||||||
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED, ENEMY_COLS, ENEMY_ROWS
|
|
||||||
from utils.preload import player_texture, enemy_texture
|
from utils.preload import player_texture, enemy_texture
|
||||||
|
|
||||||
class Bullet(arcade.Sprite):
|
class Bullet(arcade.Sprite):
|
||||||
@@ -16,9 +14,100 @@ class Bullet(arcade.Sprite):
|
|||||||
def update(self):
|
def update(self):
|
||||||
self.center_y += self.direction_y * BULLET_SPEED
|
self.center_y += self.direction_y * BULLET_SPEED
|
||||||
|
|
||||||
class Enemy(arcade.Sprite):
|
class EnemyFormation():
|
||||||
def __init__(self, x, y):
|
def __init__(self, start_x, start_y, spritelist: arcade.SpriteList | None, rows, cols):
|
||||||
super().__init__(enemy_texture, center_x=x, center_y=y)
|
self.grid = [[] for _ in range(rows)]
|
||||||
|
self.start_x = start_x
|
||||||
|
self.start_y = start_y
|
||||||
|
self.rows = rows
|
||||||
|
self.cols = cols
|
||||||
|
self.spritelist = spritelist
|
||||||
|
|
||||||
|
self.create_formation()
|
||||||
|
|
||||||
|
def create_formation(self, start_x=None, start_y=None):
|
||||||
|
if start_x:
|
||||||
|
self.start_x = start_x
|
||||||
|
if start_y:
|
||||||
|
self.start_y = start_y
|
||||||
|
|
||||||
|
del self.grid
|
||||||
|
self.grid = [[] for _ in range(self.rows)]
|
||||||
|
|
||||||
|
for row in range(self.rows):
|
||||||
|
for col in range(self.cols):
|
||||||
|
enemy_sprite = arcade.Sprite(enemy_texture, center_x=self.start_x + col * 100, center_y=self.start_y - row * 100)
|
||||||
|
|
||||||
|
if self.spritelist:
|
||||||
|
self.spritelist.append(enemy_sprite)
|
||||||
|
|
||||||
|
self.grid[row].append(enemy_sprite)
|
||||||
|
|
||||||
|
def remove_enemy(self, enemy):
|
||||||
|
if self.spritelist and enemy not in self.spritelist:
|
||||||
|
return
|
||||||
|
|
||||||
|
for row in range(self.rows):
|
||||||
|
for col in range(self.cols):
|
||||||
|
if self.grid[row][col] == enemy:
|
||||||
|
self.grid[row][col] = None
|
||||||
|
if self.spritelist:
|
||||||
|
self.spritelist.remove(enemy)
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_lowest_enemy(self):
|
||||||
|
valid_cols = []
|
||||||
|
|
||||||
|
for col in range(self.cols):
|
||||||
|
row = self.rows - 1
|
||||||
|
|
||||||
|
while row >= 0 and self.grid[row][col] is None:
|
||||||
|
row -= 1
|
||||||
|
|
||||||
|
if row >= 0:
|
||||||
|
valid_cols.append((col, row))
|
||||||
|
|
||||||
|
if not valid_cols:
|
||||||
|
return None
|
||||||
|
|
||||||
|
col, row = random.choice(valid_cols)
|
||||||
|
return self.grid[row][col]
|
||||||
|
|
||||||
|
def move(self, width, height, direction_type, value):
|
||||||
|
if direction_type == "x":
|
||||||
|
wall_hit = False
|
||||||
|
for enemy in self.enemies:
|
||||||
|
self.start_x += value
|
||||||
|
enemy.center_x += value
|
||||||
|
|
||||||
|
if enemy.center_x + enemy.width / 2 > width or enemy.center_x < enemy.width / 2:
|
||||||
|
wall_hit = True
|
||||||
|
|
||||||
|
if wall_hit:
|
||||||
|
for enemy in self.enemies:
|
||||||
|
self.start_x -= value
|
||||||
|
enemy.center_x -= value
|
||||||
|
else:
|
||||||
|
wall_hit = False
|
||||||
|
for enemy in self.enemies:
|
||||||
|
self.start_x += value
|
||||||
|
enemy.center_y += value
|
||||||
|
|
||||||
|
if enemy.center_y + enemy.height / 2 > height or enemy.center_y < enemy.height / 2:
|
||||||
|
wall_hit = True
|
||||||
|
|
||||||
|
if wall_hit:
|
||||||
|
for enemy in self.enemies:
|
||||||
|
self.start_y -= value
|
||||||
|
enemy.center_y -= value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def center_x(self):
|
||||||
|
return self.start_x + (self.cols / 2) * 100
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enemies(self):
|
||||||
|
return [col for row in self.grid for col in row if not col == None]
|
||||||
|
|
||||||
class Player(arcade.Sprite): # Not actually the player
|
class Player(arcade.Sprite): # Not actually the player
|
||||||
def __init__(self, x, y):
|
def __init__(self, x, y):
|
||||||
@@ -26,21 +115,19 @@ class Player(arcade.Sprite): # Not actually the player
|
|||||||
|
|
||||||
self.last_target_change = time.perf_counter()
|
self.last_target_change = time.perf_counter()
|
||||||
self.last_shoot = time.perf_counter()
|
self.last_shoot = time.perf_counter()
|
||||||
self.target = None
|
|
||||||
self.shoot = False
|
self.shoot = False
|
||||||
|
|
||||||
self.player_speed = 0
|
self.player_speed = 0
|
||||||
|
|
||||||
def update(self, model: PPO, enemies, bullets, width, height):
|
def update(self, model: PPO, enemy_formation, bullets, width, height, player_respawns_norm, enemy_respawns_norm):
|
||||||
if enemies:
|
if enemy_formation.enemies:
|
||||||
nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x))
|
nearest_enemy = min(enemy_formation.enemies, key=lambda e: abs(e.center_x - self.center_x))
|
||||||
enemy_x = (nearest_enemy.center_x - self.center_x) / width
|
enemy_x = (nearest_enemy.center_x - self.center_x) / width
|
||||||
enemy_y = (nearest_enemy.center_y - self.center_y) / height
|
enemy_y = (nearest_enemy.center_y - self.center_y) / height
|
||||||
else:
|
else:
|
||||||
enemy_x = 2
|
enemy_x = 2
|
||||||
enemy_y = 2
|
enemy_y = 2
|
||||||
|
|
||||||
enemy_count = len(enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS))
|
enemy_count = len(enemy_formation.enemies) / float(max(1, enemy_formation.rows * enemy_formation.cols))
|
||||||
player_x_norm = self.center_x / width
|
player_x_norm = self.center_x / width
|
||||||
|
|
||||||
curr_bullet = min(bullets, key=lambda b: abs(b.center_x - self.center_x) + abs(b.center_y - self.center_y)) if bullets else None
|
curr_bullet = min(bullets, key=lambda b: abs(b.center_x - self.center_x) + abs(b.center_y - self.center_y)) if bullets else None
|
||||||
@@ -51,18 +138,31 @@ class Player(arcade.Sprite): # Not actually the player
|
|||||||
curr_bx = 2.0
|
curr_bx = 2.0
|
||||||
curr_by = 2.0
|
curr_by = 2.0
|
||||||
|
|
||||||
lowest = max(enemies, key=lambda e: e.center_y) if enemies else None
|
lowest = max(enemy_formation.enemies, key=lambda e: e.center_y) if enemy_formation.enemies else None
|
||||||
if lowest is not None:
|
if lowest is not None:
|
||||||
lowest_dy = (lowest.center_y - self.center_y) / float(height)
|
lowest_dy = (lowest.center_y - self.center_y) / float(height)
|
||||||
else:
|
else:
|
||||||
lowest_dy = 2.0
|
lowest_dy = 2.0
|
||||||
|
|
||||||
enemy_dispersion = 0.0
|
enemy_dispersion = 0.0
|
||||||
if enemies:
|
if enemy_formation.enemies:
|
||||||
xs = np.array([e.center_x for e in enemies], dtype=np.float32)
|
xs = np.array([e.center_x for e in enemy_formation.enemies], dtype=np.float32)
|
||||||
enemy_dispersion = float(xs.std()) / float(width)
|
enemy_dispersion = float(xs.std()) / float(width)
|
||||||
|
|
||||||
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, curr_bx, curr_by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32)
|
obs = np.array([
|
||||||
|
player_x_norm,
|
||||||
|
enemy_x,
|
||||||
|
enemy_y,
|
||||||
|
lowest_dy,
|
||||||
|
curr_bx,
|
||||||
|
curr_by,
|
||||||
|
self.player_speed,
|
||||||
|
enemy_count,
|
||||||
|
enemy_dispersion,
|
||||||
|
time.perf_counter() - self.last_shoot >= PLAYER_ATTACK_SPEED,
|
||||||
|
player_respawns_norm,
|
||||||
|
enemy_respawns_norm
|
||||||
|
], dtype=np.float32)
|
||||||
action, _ = model.predict(obs, deterministic=True)
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
|
|
||||||
self.prev_x = self.center_x
|
self.prev_x = self.center_x
|
||||||
@@ -71,6 +171,8 @@ class Player(arcade.Sprite): # Not actually the player
|
|||||||
elif action == 1:
|
elif action == 1:
|
||||||
self.center_x += PLAYER_SPEED
|
self.center_x += PLAYER_SPEED
|
||||||
elif action == 2:
|
elif action == 2:
|
||||||
|
pass
|
||||||
|
elif action == 3:
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
if t - self.last_shoot >= PLAYER_ATTACK_SPEED:
|
if t - self.last_shoot >= PLAYER_ATTACK_SPEED:
|
||||||
self.last_shoot = t
|
self.last_shoot = t
|
||||||
|
|||||||
BIN
invader_agent.zip
Normal file
BIN
invader_agent.zip
Normal file
Binary file not shown.
@@ -62,8 +62,8 @@ class Main(arcade.gui.UIView):
|
|||||||
self.settings_button.on_click = lambda event: self.settings()
|
self.settings_button.on_click = lambda event: self.settings()
|
||||||
|
|
||||||
def play(self):
|
def play(self):
|
||||||
from game.play import Game
|
from menus.mode_selector import ModeSelector
|
||||||
self.window.show_view(Game(self.pypresence_client))
|
self.window.show_view(ModeSelector(self.pypresence_client))
|
||||||
|
|
||||||
def settings(self):
|
def settings(self):
|
||||||
from menus.settings import Settings
|
from menus.settings import Settings
|
||||||
|
|||||||
72
menus/mode_selector.py
Normal file
72
menus/mode_selector.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import arcade, arcade.gui
|
||||||
|
|
||||||
|
from utils.preload import button_texture, button_hovered_texture
|
||||||
|
from utils.constants import dropdown_style, button_style, DIFFICULTY_LEVELS, DIFFICULTY_SETTINGS
|
||||||
|
|
||||||
|
class ModeSelector(arcade.gui.UIView):
|
||||||
|
def __init__(self, pypresence_client):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pypresence_client = pypresence_client
|
||||||
|
self.pypresence_client.update(state="Selecting Mode")
|
||||||
|
|
||||||
|
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
|
||||||
|
self.box = self.anchor.add(arcade.gui.UIBoxLayout(size_hint=(0.75, 0.75), space_between=10), anchor_x="center", anchor_y="center")
|
||||||
|
|
||||||
|
self.settings = DIFFICULTY_LEVELS["Easy"]
|
||||||
|
self.setting_sliders = {}
|
||||||
|
self.setting_labels = {}
|
||||||
|
|
||||||
|
def on_show_view(self):
|
||||||
|
super().on_show_view()
|
||||||
|
|
||||||
|
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
||||||
|
self.back_button.on_click = lambda event: self.main_exit()
|
||||||
|
|
||||||
|
self.box.add(arcade.gui.UILabel("Settings", font_size=32))
|
||||||
|
|
||||||
|
self.box.add(arcade.gui.UISpace(height=self.window.height / 80))
|
||||||
|
|
||||||
|
self.difficulty_selector = self.box.add(arcade.gui.UIDropdown(default="Easy", options=list(DIFFICULTY_LEVELS.keys()), active_style=dropdown_style, primary_style=dropdown_style, dropdown_style=dropdown_style, width=self.window.width / 2, height=self.window.height / 20))
|
||||||
|
self.difficulty_selector.on_change = lambda event: self.set_difficulty_values(event.new_value)
|
||||||
|
|
||||||
|
self.box.add(arcade.gui.UISpace(height=self.window.height / 80))
|
||||||
|
|
||||||
|
for key, data in DIFFICULTY_SETTINGS.items():
|
||||||
|
default, name, min_value, max_value = DIFFICULTY_LEVELS["Easy"][key], *data
|
||||||
|
|
||||||
|
label = self.box.add(arcade.gui.UILabel(text=f"{name}: {default}", font_size=14))
|
||||||
|
|
||||||
|
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=1, width=self.window.width / 2, height=self.window.height / 25))
|
||||||
|
slider._render_steps = lambda surface: None
|
||||||
|
slider.on_event = lambda event: None # disable slider for difficulties
|
||||||
|
slider.on_click = lambda event: None # disable slider for difficulties
|
||||||
|
slider.on_change = lambda e, key=key: self.change_value(key, e.new_value)
|
||||||
|
|
||||||
|
self.setting_sliders[key] = slider
|
||||||
|
self.setting_labels[key] = label
|
||||||
|
|
||||||
|
self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", width=self.window.width / 2, height=self.window.height / 15, texture=button_texture, texture_hovered=button_hovered_texture, style=button_style))
|
||||||
|
self.play_button.on_click = lambda event: self.start_game()
|
||||||
|
|
||||||
|
def set_difficulty_values(self, difficulty):
|
||||||
|
for key, value in DIFFICULTY_LEVELS[difficulty].items():
|
||||||
|
self.settings[key] = value
|
||||||
|
self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {value}"
|
||||||
|
self.setting_sliders[key].value = value
|
||||||
|
|
||||||
|
for slider in self.setting_sliders.values():
|
||||||
|
if difficulty != "Custom":
|
||||||
|
slider.on_event = lambda event: None
|
||||||
|
slider.on_click = lambda event: None
|
||||||
|
else:
|
||||||
|
slider.on_event = lambda event, slider=slider: arcade.gui.UISlider.on_event(slider, event)
|
||||||
|
slider.on_click = lambda event, slider=slider: arcade.gui.UISlider.on_click(slider, event)
|
||||||
|
|
||||||
|
def change_value(self, key, value):
|
||||||
|
self.settings[key] = int(value)
|
||||||
|
self.setting_labels[key].text = f"{DIFFICULTY_SETTINGS[key][0]}: {int(value)}"
|
||||||
|
|
||||||
|
def start_game(self):
|
||||||
|
from game.play import Game
|
||||||
|
self.window.show_view(Game(self.pypresence_client, self.settings))
|
||||||
@@ -1,20 +1,26 @@
|
|||||||
import arcade, arcade.gui, threading, io, os, time
|
import arcade, arcade.gui, threading, os, queue, time, shutil
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
|
from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
|
||||||
from utils.preload import button_texture, button_hovered_texture
|
from utils.preload import button_texture, button_hovered_texture
|
||||||
|
from utils.rl import SpaceInvadersEnv
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from io import BytesIO
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
from stable_baselines3.common.logger import configure
|
from stable_baselines3.common.logger import configure
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
|
||||||
from utils.rl import SpaceInvadersEnv
|
def make_env(rank: int, seed: int = 0):
|
||||||
|
def _init():
|
||||||
|
env = SpaceInvadersEnv()
|
||||||
|
env = Monitor(env, filename=os.path.join(monitor_log_dir, f"monitor_{rank}.csv"))
|
||||||
|
return env
|
||||||
|
return _init
|
||||||
|
|
||||||
class TrainModel(arcade.gui.UIView):
|
class TrainModel(arcade.gui.UIView):
|
||||||
def __init__(self, pypresence_client):
|
def __init__(self, pypresence_client):
|
||||||
@@ -24,22 +30,25 @@ class TrainModel(arcade.gui.UIView):
|
|||||||
self.pypresence_client.update(state="Model Training")
|
self.pypresence_client.update(state="Model Training")
|
||||||
|
|
||||||
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
|
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
|
||||||
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))
|
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=5))
|
||||||
|
|
||||||
self.settings = {
|
self.settings = {
|
||||||
setting: data[0] # default value
|
setting: data[0]
|
||||||
for setting, data in MODEL_SETTINGS.items()
|
for setting, data in MODEL_SETTINGS.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
self.labels = {}
|
self.labels = {}
|
||||||
|
|
||||||
self.training = False
|
self.training = False
|
||||||
self.training_text = ""
|
self.training_text = "Starting training..."
|
||||||
|
|
||||||
|
self.result_queue = queue.Queue()
|
||||||
|
self.training_thread = None
|
||||||
|
|
||||||
self.last_progress_update = time.perf_counter()
|
self.last_progress_update = time.perf_counter()
|
||||||
|
|
||||||
def on_show_view(self):
|
def on_show_view(self):
|
||||||
super().on_show_view()
|
super().on_show_view()
|
||||||
|
|
||||||
self.show_menu()
|
self.show_menu()
|
||||||
|
|
||||||
def main_exit(self):
|
def main_exit(self):
|
||||||
@@ -50,128 +59,188 @@ class TrainModel(arcade.gui.UIView):
|
|||||||
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
||||||
self.back_button.on_click = lambda event: self.main_exit()
|
self.back_button.on_click = lambda event: self.main_exit()
|
||||||
|
|
||||||
self.box.add(arcade.gui.UILabel("Settings", font_size=36))
|
self.box.add(arcade.gui.UILabel("Settings", font_size=32))
|
||||||
|
|
||||||
for setting, data in MODEL_SETTINGS.items():
|
for setting, data in MODEL_SETTINGS.items():
|
||||||
default, min_value, max_value, step = data
|
default, min_value, max_value, step = data
|
||||||
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18))
|
is_int = setting == "n_envs" or (abs(step - 1) < 1e-6 and abs(min_value - round(min_value)) < 1e-6)
|
||||||
|
|
||||||
|
val_text = str(int(default)) if is_int else str(default)
|
||||||
|
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {val_text}", font_size=14))
|
||||||
|
|
||||||
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
|
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
|
||||||
slider._render_steps = lambda surface: None
|
slider._render_steps = lambda surface: None
|
||||||
slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)
|
slider.on_change = lambda e, key=setting, is_int_slider=is_int: self.change_value(key, e.new_value, is_int_slider)
|
||||||
|
|
||||||
self.labels[setting] = label
|
self.labels[setting] = label
|
||||||
|
|
||||||
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
|
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
|
||||||
train_button.on_click = lambda e: self.start_training()
|
train_button.on_click = lambda e: self.start_training()
|
||||||
|
|
||||||
def change_value(self, key, value):
|
def change_value(self, key, value, is_int=False):
|
||||||
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}"
|
if is_int:
|
||||||
self.settings[key] = self.round_near_int(value)
|
val = int(round(value))
|
||||||
|
self.settings[key] = val
|
||||||
|
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
|
||||||
|
else:
|
||||||
|
val = self.round_near_int(value)
|
||||||
|
self.settings[key] = val
|
||||||
|
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {val}"
|
||||||
|
|
||||||
def start_training(self):
|
def start_training(self):
|
||||||
self.box.clear()
|
self.box.clear()
|
||||||
|
|
||||||
self.training = True
|
self.training_text = "Starting training..."
|
||||||
|
self.training_label = self.box.add(arcade.gui.UILabel("Starting training...", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
|
||||||
self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
|
|
||||||
|
|
||||||
self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
|
self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
|
||||||
self.plot_image_widget.visible = False
|
self.plot_image_widget.visible = False
|
||||||
|
|
||||||
threading.Thread(target=self.train, daemon=True).start()
|
self.training_thread = threading.Thread(target=self.train, daemon=True)
|
||||||
|
self.training_thread.start()
|
||||||
|
|
||||||
def on_update(self, delta_time):
|
def on_update(self, delta_time):
|
||||||
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5:
|
|
||||||
self.last_progress_update = time.perf_counter()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv"))
|
result = self.result_queue.get_nowait()
|
||||||
except pd.errors.EmptyDataError:
|
|
||||||
return
|
|
||||||
|
|
||||||
progress_text = ""
|
if result["type"] == "text":
|
||||||
|
self.training_text = result["message"]
|
||||||
|
|
||||||
for key, value in progress_df.items():
|
elif result["type"] == "plot":
|
||||||
progress_text += f"{key}: {round(value.iloc[-1], 6)}\n"
|
self.plot_image_widget.texture = result["image"]
|
||||||
|
self.plot_image_widget.width = result["image"].width
|
||||||
|
self.plot_image_widget.height = result["image"].height
|
||||||
|
self.plot_image_widget.trigger_render()
|
||||||
|
self.plot_image_widget.visible = True
|
||||||
|
|
||||||
self.training_text = progress_text
|
elif result["type"] == "finished":
|
||||||
|
self.training = False
|
||||||
|
self.training_text = "Training finished."
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and all([os.path.exists(os.path.join(monitor_log_dir, f"monitor_{i}.csv.monitor.csv")) for i in range(int(self.settings["n_envs"]))]) and time.perf_counter() - self.last_progress_update >= 1:
|
||||||
|
self.last_progress_update = time.perf_counter()
|
||||||
|
self.plot_results()
|
||||||
|
|
||||||
if hasattr(self, "training_label"):
|
if hasattr(self, "training_label"):
|
||||||
self.training_label.text = self.training_text
|
self.training_label.text = self.training_text
|
||||||
|
|
||||||
def round_near_int(self, x, tol=1e-4):
|
def round_near_int(self, x, tol=1e-4):
|
||||||
nearest = round(x)
|
nearest = round(x)
|
||||||
|
|
||||||
if abs(x - nearest) < tol:
|
if abs(x - nearest) < tol:
|
||||||
return nearest
|
return nearest
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
os.makedirs(monitor_log_dir, exist_ok=True)
|
if os.path.exists(monitor_log_dir):
|
||||||
env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv"))
|
shutil.rmtree(monitor_log_dir)
|
||||||
|
os.makedirs(monitor_log_dir)
|
||||||
|
|
||||||
|
n_envs = int(self.settings["n_envs"])
|
||||||
|
env = DummyVecEnv([make_env(i) for i in range(n_envs)])
|
||||||
|
|
||||||
|
n_steps = int(self.settings["n_steps"])
|
||||||
|
batch_size = int(self.settings["batch_size"])
|
||||||
|
|
||||||
|
total_steps_per_rollout = n_steps * max(1, n_envs)
|
||||||
|
if total_steps_per_rollout % batch_size != 0:
|
||||||
|
batch_size = max(64, total_steps_per_rollout // max(1, total_steps_per_rollout // batch_size))
|
||||||
|
print(f"Warning: Adjusting batch size to {batch_size} for {n_envs} envs.")
|
||||||
|
|
||||||
model = PPO(
|
model = PPO(
|
||||||
"MlpPolicy",
|
"MlpPolicy",
|
||||||
env,
|
env,
|
||||||
n_steps=self.settings["n_steps"],
|
n_steps=n_steps,
|
||||||
batch_size=self.settings["batch_size"],
|
batch_size=batch_size,
|
||||||
n_epochs=self.settings["n_epochs"],
|
n_epochs=int(self.settings["n_epochs"]),
|
||||||
learning_rate=self.settings["learning_rate"],
|
learning_rate=float(self.settings["learning_rate"]),
|
||||||
verbose=1,
|
verbose=1,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
gamma=self.settings["gamma"],
|
gamma=float(self.settings["gamma"]),
|
||||||
ent_coef=self.settings["ent_coef"],
|
ent_coef=float(self.settings["ent_coef"]),
|
||||||
clip_range=self.settings["clip_range"],
|
clip_range=float(self.settings["clip_range"]),
|
||||||
)
|
|
||||||
|
|
||||||
new_logger = configure(
|
|
||||||
folder=monitor_log_dir, format_strings=["csv"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
new_logger = configure(folder=monitor_log_dir, format_strings=["csv"])
|
||||||
model.set_logger(new_logger)
|
model.set_logger(new_logger)
|
||||||
|
|
||||||
model.learn(self.settings["learning_steps"])
|
try:
|
||||||
|
self.training = True
|
||||||
|
model.learn(int(self.settings["learning_steps"]))
|
||||||
model.save("invader_agent")
|
model.save("invader_agent")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during training: {e}")
|
||||||
|
self.result_queue.put({"type": "text", "message": f"Error:\n{e}"})
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
env.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.result_queue.put({"type": "finished"})
|
||||||
|
|
||||||
self.training = False
|
def plot_results(self):
|
||||||
|
try:
|
||||||
|
reward_df = pd.read_csv(os.path.join(monitor_log_dir, "progress.csv"))
|
||||||
|
except pd.errors.EmptyDataError:
|
||||||
|
return
|
||||||
|
|
||||||
self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv"))
|
all_monitor_files = [os.path.join(monitor_log_dir, f) for f in os.listdir(monitor_log_dir) if f.startswith("monitor_") and f.endswith(".csv")]
|
||||||
|
try:
|
||||||
|
df_list = [pd.read_csv(f, skiprows=1) for f in all_monitor_files]
|
||||||
|
except pd.errors.EmptyDataError:
|
||||||
|
return
|
||||||
|
|
||||||
def plot_results(self, log_path, loss_log_path):
|
monitor_df = pd.concat(df_list).sort_values(by='t')
|
||||||
df = pd.read_csv(log_path, skiprows=1)
|
monitor_df['total_timesteps'] = monitor_df['l'].cumsum()
|
||||||
|
|
||||||
|
loss_log_path = os.path.join(monitor_log_dir, "progress.csv")
|
||||||
|
loss_df = None
|
||||||
|
if os.path.exists(loss_log_path):
|
||||||
|
try:
|
||||||
|
loss_df = pd.read_csv(loss_log_path)
|
||||||
|
except Exception:
|
||||||
|
loss_df = None
|
||||||
|
|
||||||
fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)
|
fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)
|
||||||
|
|
||||||
loss_df = pd.read_csv(loss_log_path)
|
if monitor_df is not None and 'total_timesteps' in monitor_df.columns and 'r' in monitor_df.columns:
|
||||||
|
axes[0].plot(monitor_df['total_timesteps'], monitor_df['r'].rolling(window=10).mean(), label='Episodic Reward (Rolling 10)')
|
||||||
|
elif reward_df is not None and 'time/total_timesteps' in reward_df.columns and 'rollout/ep_rew_mean' in reward_df.columns:
|
||||||
|
axes[0].plot(reward_df['time/total_timesteps'], reward_df['rollout/ep_rew_mean'], label='Ep reward mean')
|
||||||
|
else:
|
||||||
|
axes[0].text(0.5, 0.5, "No reward data available", horizontalalignment='center', verticalalignment='center')
|
||||||
|
|
||||||
axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward')
|
|
||||||
axes[0].set_title('PPO Training: Episodic Reward')
|
axes[0].set_title('PPO Training: Episodic Reward')
|
||||||
axes[0].set_xlabel('Total Timesteps')
|
axes[0].set_xlabel('Total Timesteps')
|
||||||
axes[0].set_ylabel('Reward')
|
axes[0].set_ylabel('Reward')
|
||||||
axes[0].grid(True)
|
axes[0].grid(True)
|
||||||
|
|
||||||
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss')
|
axes[1].set_title('PPO Training: Loss & Variance')
|
||||||
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
|
|
||||||
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
|
|
||||||
axes[1].set_title('PPO Training: Loss Functions')
|
|
||||||
axes[1].set_xlabel('Total Timesteps')
|
axes[1].set_xlabel('Total Timesteps')
|
||||||
axes[1].set_ylabel('Loss Value')
|
axes[1].set_ylabel('Value')
|
||||||
axes[1].legend()
|
|
||||||
axes[1].grid(True)
|
axes[1].grid(True)
|
||||||
|
|
||||||
|
if loss_df is not None and 'time/total_timesteps' in loss_df.columns and 'train/policy_gradient_loss' in loss_df.columns and 'train/value_loss' in loss_df.columns and 'train/explained_variance' in loss_df.columns:
|
||||||
|
tcol = 'time/total_timesteps'
|
||||||
|
axes[1].plot(loss_df[tcol], loss_df['train/policy_gradient_loss'], label='Policy Loss')
|
||||||
|
axes[1].plot(loss_df[tcol], loss_df['train/value_loss'], label='Value Loss')
|
||||||
|
axes[1].plot(loss_df[tcol], loss_df['train/explained_variance'], label='Explained Variance')
|
||||||
|
|
||||||
|
axes[1].legend()
|
||||||
|
else:
|
||||||
|
axes[1].text(0.5, 0.5, "No loss/variance data available", horizontalalignment='center', verticalalignment='center')
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
plt.savefig(buffer, format='png')
|
plt.savefig(buffer, format='png', bbox_inches='tight')
|
||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
plot_texture = arcade.Texture(Image.open(buffer))
|
pil_img = Image.open(buffer).convert("RGBA")
|
||||||
|
|
||||||
self.plot_image_widget.texture = plot_texture
|
plot_texture = arcade.Texture(pil_img)
|
||||||
self.plot_image_widget.size_hint = (None, None)
|
|
||||||
self.plot_image_widget.width = plot_texture.width
|
|
||||||
self.plot_image_widget.height = plot_texture.height
|
|
||||||
|
|
||||||
self.plot_image_widget.visible = True
|
self.result_queue.put({"type": "plot", "image": plot_texture})
|
||||||
self.training_text = "Training finished. Plot displayed."
|
|
||||||
|
|||||||
2
run.py
2
run.py
@@ -29,7 +29,7 @@ timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|||||||
log_filename = f"debug_{timestamp}.log"
|
log_filename = f"debug_{timestamp}.log"
|
||||||
logging.basicConfig(filename=f'{os.path.join(log_dir, log_filename)}', format='%(asctime)s %(name)s %(levelname)s: %(message)s', level=logging.DEBUG)
|
logging.basicConfig(filename=f'{os.path.join(log_dir, log_filename)}', format='%(asctime)s %(name)s %(levelname)s: %(message)s', level=logging.DEBUG)
|
||||||
|
|
||||||
for logger_name_to_disable in ['arcade']:
|
for logger_name_to_disable in ['arcade', "matplotlib", "matplotlib.fontmanager", "PIL"]:
|
||||||
logging.getLogger(logger_name_to_disable).propagate = False
|
logging.getLogger(logger_name_to_disable).propagate = False
|
||||||
logging.getLogger(logger_name_to_disable).disabled = True
|
logging.getLogger(logger_name_to_disable).disabled = True
|
||||||
|
|
||||||
|
|||||||
@@ -3,25 +3,68 @@ from arcade.types import Color
|
|||||||
from arcade.gui.widgets.buttons import UITextureButtonStyle, UIFlatButtonStyle
|
from arcade.gui.widgets.buttons import UITextureButtonStyle, UIFlatButtonStyle
|
||||||
from arcade.gui.widgets.slider import UISliderStyle
|
from arcade.gui.widgets.slider import UISliderStyle
|
||||||
|
|
||||||
ENEMY_ROWS = 3
|
ENEMY_SPEED = 5
|
||||||
ENEMY_COLS = 13
|
ENEMY_ATTACK_SPEED = 0.75
|
||||||
|
|
||||||
PLAYER_SPEED = 5 # not actually player
|
PLAYER_SPEED = 5 # not actually player
|
||||||
PLAYER_ATTACK_SPEED = 0.75
|
PLAYER_ATTACK_SPEED = 0.75
|
||||||
|
|
||||||
BULLET_SPEED = 3
|
BULLET_SPEED = 5
|
||||||
BULLET_RADIUS = 10
|
BULLET_RADIUS = 15
|
||||||
|
|
||||||
# default, min, max, step
|
# default, min, max, step
|
||||||
MODEL_SETTINGS = {
|
MODEL_SETTINGS = {
|
||||||
"n_steps": [2048, 256, 8192, 256],
|
"n_steps": [1024, 256, 8192, 256],
|
||||||
"batch_size": [64, 16, 512, 16],
|
"batch_size": [128, 16, 512, 16],
|
||||||
"n_epochs": [10, 1, 50, 1],
|
"n_epochs": [10, 1, 50, 1],
|
||||||
"learning_rate": [3e-4, 1e-5, 1e-2, 1e-5],
|
"learning_rate": [3e-4, 1e-5, 1e-2, 1e-5],
|
||||||
"gamma": [0.99, 0.8, 0.9999, 0.001],
|
"gamma": [0.99, 0.8, 0.9999, 0.001],
|
||||||
"ent_coef": [0.01, 0.0, 0.1, 0.001],
|
"ent_coef": [0.015, 0.0, 0.1, 0.001],
|
||||||
"clip_range": [0.2, 0.1, 0.4, 0.01],
|
"clip_range": [0.2, 0.1, 0.4, 0.01],
|
||||||
"learning_steps": [500_000, 50_000, 25_000_000, 50_000]
|
"learning_steps": [1_000_000, 50_000, 25_000_000, 50_000],
|
||||||
|
"n_envs": (12, 1, 128, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
DIFFICULTY_SETTINGS = {
|
||||||
|
"enemy_rows": ["Enemy Rows", 1, 6],
|
||||||
|
"enemy_cols": ["Enemy Columns", 1, 7],
|
||||||
|
"enemy_respawns": ["Enemy Respawns", 1, 5],
|
||||||
|
"player_count": ["Player Count", 1, 10],
|
||||||
|
"player_respawns": ["Player Respawns", 1, 5]
|
||||||
|
}
|
||||||
|
|
||||||
|
DIFFICULTY_LEVELS = {
|
||||||
|
"Easy": {
|
||||||
|
"enemy_rows": 3,
|
||||||
|
"enemy_cols": 4,
|
||||||
|
"enemy_respawns": 5,
|
||||||
|
"player_count": 2,
|
||||||
|
"player_respawns": 2
|
||||||
|
},
|
||||||
|
"Medium": {
|
||||||
|
"enemy_rows": 3,
|
||||||
|
"enemy_cols": 5,
|
||||||
|
"enemy_respawns": 4,
|
||||||
|
"player_count": 4,
|
||||||
|
"player_respawns": 3
|
||||||
|
},
|
||||||
|
"Hard": {
|
||||||
|
"enemy_rows": 4,
|
||||||
|
"enemy_cols": 6,
|
||||||
|
"enemy_respawns": 3,
|
||||||
|
"player_count": 6,
|
||||||
|
"player_respawns": 4
|
||||||
|
},
|
||||||
|
"Extra Hard": {
|
||||||
|
"enemy_rows": 6,
|
||||||
|
"enemy_cols": 7,
|
||||||
|
"enemy_respawns": 2,
|
||||||
|
"player_count": 8,
|
||||||
|
"player_respawns": 5
|
||||||
|
},
|
||||||
|
"Custom": {
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
menu_background_color = (30, 30, 47)
|
menu_background_color = (30, 30, 47)
|
||||||
|
|||||||
274
utils/rl.py
274
utils/rl.py
@@ -1,78 +1,87 @@
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import arcade
|
import arcade
|
||||||
import time
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from game.sprites import Enemy, Player, Bullet
|
from game.sprites import EnemyFormation, Player, Bullet
|
||||||
from utils.constants import PLAYER_SPEED, BULLET_SPEED, PLAYER_ATTACK_SPEED, ENEMY_ROWS, ENEMY_COLS
|
from utils.constants import PLAYER_SPEED, BULLET_SPEED, ENEMY_SPEED, DIFFICULTY_LEVELS
|
||||||
|
|
||||||
class SpaceInvadersEnv(gym.Env):
|
class SpaceInvadersEnv(gym.Env):
|
||||||
def __init__(self, width=800, height=600):
|
def __init__(self, width=800, height=600, difficulty="Hard"):
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
|
|
||||||
self.action_space = gym.spaces.Discrete(3)
|
self.action_space = gym.spaces.Discrete(4)
|
||||||
self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(9,), dtype=np.float32)
|
self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)
|
||||||
|
|
||||||
|
if difficulty not in DIFFICULTY_LEVELS:
|
||||||
|
raise ValueError(f"Unknown difficulty: {difficulty}. Available: {list(DIFFICULTY_LEVELS.keys())}")
|
||||||
|
|
||||||
|
self.difficulty_settings = DIFFICULTY_LEVELS[difficulty]
|
||||||
|
|
||||||
self.enemies = []
|
|
||||||
self.bullets = []
|
self.bullets = []
|
||||||
self.dir_history = []
|
|
||||||
self.last_shot = 0.0
|
|
||||||
self.player = None
|
self.player = None
|
||||||
self.prev_x = 0.0
|
self.enemy_formation = None
|
||||||
self.player_speed = 0.0
|
self.player_speed = 0.0
|
||||||
self.prev_bx = 2.0
|
self.max_steps = 2000
|
||||||
self.steps_since_direction_change = 0
|
|
||||||
self.last_direction = 0
|
|
||||||
self.max_steps = 1000
|
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
|
self.enemies_killed = 0
|
||||||
|
self.enemy_move_speed = ENEMY_SPEED
|
||||||
|
self.player_respawns = self.difficulty_settings["player_respawns"]
|
||||||
|
self.enemy_respawns = self.difficulty_settings["enemy_respawns"]
|
||||||
|
self.player_respawns_remaining = 0
|
||||||
|
self.enemy_respawns_remaining = 0
|
||||||
|
self.player_alive = True
|
||||||
|
|
||||||
|
self.player_attack_cooldown_steps = 5
|
||||||
|
self.current_cooldown = 0
|
||||||
|
|
||||||
def reset(self, seed=None, options=None):
|
def reset(self, seed=None, options=None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
self.enemies = []
|
|
||||||
self.bullets = []
|
self.bullets = []
|
||||||
self.dir_history = []
|
|
||||||
self.player = Player(self.width / 2 + random.randint(int(-self.width / 3), int(self.width / 3)), 100)
|
self.player = Player(self.width / 2 + random.randint(int(-self.width / 3), int(self.width / 3)), 100)
|
||||||
self.prev_x = self.player.center_x
|
|
||||||
self.player_speed = 0.0
|
self.player_speed = 0.0
|
||||||
self.prev_bx = 2.0
|
|
||||||
self.steps_since_direction_change = 0
|
|
||||||
self.last_direction = 0
|
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
|
self.enemies_killed = 0
|
||||||
|
self.player_respawns_remaining = self.player_respawns
|
||||||
|
self.enemy_respawns_remaining = self.enemy_respawns
|
||||||
|
self.player_alive = True
|
||||||
|
self.current_cooldown = 0
|
||||||
|
|
||||||
start_x = self.width * 0.15
|
start_x = self.width * 0.15
|
||||||
start_y = self.height * 0.9
|
start_y = self.height * 0.9
|
||||||
|
|
||||||
for r in range(ENEMY_ROWS):
|
self.enemy_formation = EnemyFormation(start_x, start_y, None,
|
||||||
for c in range(ENEMY_COLS):
|
self.difficulty_settings["enemy_rows"],
|
||||||
e = Enemy(start_x + c * 100, start_y - r * 100)
|
self.difficulty_settings["enemy_cols"])
|
||||||
self.enemies.append(e)
|
|
||||||
|
|
||||||
self.last_shot = time.perf_counter()
|
|
||||||
return self._obs(), {}
|
return self._obs(), {}
|
||||||
|
|
||||||
def _nearest_enemy(self):
|
def _nearest_enemy(self):
|
||||||
if not self.enemies:
|
if not self.enemy_formation.enemies:
|
||||||
return None
|
return None
|
||||||
return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
|
|
||||||
|
return min(self.enemy_formation.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
|
||||||
|
|
||||||
def _lowest_enemy(self):
|
def _lowest_enemy(self):
|
||||||
if not self.enemies:
|
if not self.enemy_formation.enemies:
|
||||||
return None
|
return None
|
||||||
return max(self.enemies, key=lambda e: e.center_y)
|
|
||||||
|
return min(self.enemy_formation.enemies, key=lambda e: e.center_y)
|
||||||
|
|
||||||
def _nearest_enemy_bullet(self):
|
def _nearest_enemy_bullet(self):
|
||||||
enemy_bullets = [b for b in self.bullets if b.direction_y == -1]
|
enemy_bullets = [b for b in self.bullets if b.direction_y == -1]
|
||||||
|
|
||||||
if not enemy_bullets:
|
if not enemy_bullets:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return min(enemy_bullets, key=lambda b: abs(b.center_x - self.player.center_x) + abs(b.center_y - self.player.center_y))
|
return min(enemy_bullets, key=lambda b: abs(b.center_x - self.player.center_x) + abs(b.center_y - self.player.center_y))
|
||||||
|
|
||||||
def _obs(self):
|
def _obs(self):
|
||||||
if self.enemies:
|
if self.enemy_formation.enemies and self.player_alive:
|
||||||
nearest = self._nearest_enemy()
|
nearest = self._nearest_enemy()
|
||||||
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
|
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
|
||||||
enemy_y = (nearest.center_y - self.player.center_y) / float(self.height)
|
enemy_y = (nearest.center_y - self.player.center_y) / float(self.height)
|
||||||
@@ -81,31 +90,60 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
enemy_y = 2.0
|
enemy_y = 2.0
|
||||||
|
|
||||||
lowest = self._lowest_enemy()
|
lowest = self._lowest_enemy()
|
||||||
|
if lowest is not None and self.player_alive:
|
||||||
if lowest is not None:
|
|
||||||
lowest_dy = (lowest.center_y - self.player.center_y) / float(self.height)
|
lowest_dy = (lowest.center_y - self.player.center_y) / float(self.height)
|
||||||
else:
|
else:
|
||||||
lowest_dy = 2.0
|
lowest_dy = 2.0
|
||||||
|
|
||||||
nb = self._nearest_enemy_bullet()
|
nb = self._nearest_enemy_bullet()
|
||||||
if nb is not None:
|
if nb is not None and self.player_alive:
|
||||||
bx = (nb.center_x - self.player.center_x) / float(self.width)
|
bx = (nb.center_x - self.player.center_x) / float(self.width)
|
||||||
by = (nb.center_y - self.player.center_y) / float(self.height)
|
by = (nb.center_y - self.player.center_y) / float(self.height)
|
||||||
else:
|
else:
|
||||||
bx = 2.0
|
bx = 2.0
|
||||||
by = 2.0
|
by = 2.0
|
||||||
|
|
||||||
enemy_count = len(self.enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS))
|
enemy_count = len(self.enemy_formation.enemies) / float(max(1, self.difficulty_settings["enemy_rows"] * self.difficulty_settings["enemy_cols"]))
|
||||||
player_x_norm = self.player.center_x / float(self.width)
|
player_x_norm = self.player.center_x / float(self.width) if self.player_alive else 0.5
|
||||||
enemy_dispersion = 0.0
|
|
||||||
|
|
||||||
if self.enemies:
|
enemy_dispersion = 0.0
|
||||||
xs = np.array([e.center_x for e in self.enemies], dtype=np.float32)
|
if self.enemy_formation.enemies:
|
||||||
|
xs = np.array([e.center_x for e in self.enemy_formation.enemies], dtype=np.float32)
|
||||||
enemy_dispersion = float(xs.std()) / float(self.width)
|
enemy_dispersion = float(xs.std()) / float(self.width)
|
||||||
|
|
||||||
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, bx, by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32)
|
can_shoot = 1.0 if (self.player_alive and self.current_cooldown <= 0) else 0.0
|
||||||
|
|
||||||
|
player_respawns_norm = self.player_respawns_remaining / float(max(1, self.player_respawns))
|
||||||
|
enemy_respawns_norm = self.enemy_respawns_remaining / float(max(1, self.enemy_respawns))
|
||||||
|
|
||||||
|
obs = np.array([
|
||||||
|
player_x_norm,
|
||||||
|
enemy_x,
|
||||||
|
enemy_y,
|
||||||
|
lowest_dy,
|
||||||
|
bx,
|
||||||
|
by,
|
||||||
|
self.player_speed,
|
||||||
|
enemy_count,
|
||||||
|
enemy_dispersion,
|
||||||
|
can_shoot,
|
||||||
|
player_respawns_norm,
|
||||||
|
enemy_respawns_norm
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
def _respawn_player(self):
|
||||||
|
self.player = Player(self.width / 2 + random.randint(int(-self.width / 3), int(self.width / 3)), 100)
|
||||||
|
self.player_alive = True
|
||||||
|
self.bullets = [b for b in self.bullets if b.direction_y == 1]
|
||||||
|
self.current_cooldown = 0
|
||||||
|
|
||||||
|
def _respawn_enemies(self):
|
||||||
|
self.enemy_formation.start_x = self.width * 0.15
|
||||||
|
self.enemy_formation.start_y = self.height * 0.9
|
||||||
|
self.enemy_formation.create_formation()
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
reward = 0.0
|
reward = 0.0
|
||||||
terminated = False
|
terminated = False
|
||||||
@@ -115,113 +153,125 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
if self.current_step >= self.max_steps:
|
if self.current_step >= self.max_steps:
|
||||||
truncated = True
|
truncated = True
|
||||||
|
|
||||||
nearest = self._nearest_enemy()
|
if self.current_cooldown > 0:
|
||||||
if nearest is not None:
|
self.current_cooldown -= 1
|
||||||
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
|
|
||||||
else:
|
|
||||||
enemy_x = 2.0
|
|
||||||
|
|
||||||
|
if self.player_alive:
|
||||||
prev_x = self.player.center_x
|
prev_x = self.player.center_x
|
||||||
current_action_dir = 0
|
|
||||||
|
|
||||||
if action == 0:
|
if action == 0:
|
||||||
self.player.center_x -= PLAYER_SPEED
|
self.player.center_x -= PLAYER_SPEED
|
||||||
current_action_dir = -1
|
|
||||||
elif action == 1:
|
elif action == 1:
|
||||||
self.player.center_x += PLAYER_SPEED
|
self.player.center_x += PLAYER_SPEED
|
||||||
current_action_dir = 1
|
|
||||||
elif action == 2:
|
elif action == 2:
|
||||||
t = time.perf_counter()
|
pass
|
||||||
if t - self.last_shot >= PLAYER_ATTACK_SPEED:
|
elif action == 3:
|
||||||
self.last_shot = t
|
if self.current_cooldown <= 0:
|
||||||
|
self.current_cooldown = self.player_attack_cooldown_steps
|
||||||
|
reward += 0.01
|
||||||
b = Bullet(self.player.center_x, self.player.center_y, 1)
|
b = Bullet(self.player.center_x, self.player.center_y, 1)
|
||||||
|
|
||||||
self.bullets.append(b)
|
self.bullets.append(b)
|
||||||
|
else:
|
||||||
|
reward -= 0.05
|
||||||
|
|
||||||
if enemy_x != 2.0 and abs(enemy_x) < 0.04:
|
if self.enemy_formation.enemies:
|
||||||
|
nearest = self._nearest_enemy()
|
||||||
|
alignment = abs(nearest.center_x - self.player.center_x) / self.width
|
||||||
|
if alignment < 0.025:
|
||||||
reward += 0.3
|
reward += 0.3
|
||||||
elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
|
|
||||||
reward += 0.1
|
|
||||||
|
|
||||||
if self.player.center_x > self.width:
|
|
||||||
self.player.center_x = self.width
|
|
||||||
elif self.player.center_x < 0:
|
|
||||||
self.player.center_x = 0
|
|
||||||
|
|
||||||
|
self.player.center_x = np.clip(self.player.center_x, 0, self.width)
|
||||||
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
|
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
|
||||||
|
|
||||||
if current_action_dir != 0:
|
if self.enemy_formation.enemies and self.player_alive:
|
||||||
if self.last_direction != 0 and current_action_dir != self.last_direction:
|
if self.enemy_formation.center_x < self.player.center_x:
|
||||||
if self.steps_since_direction_change < 3:
|
self.enemy_formation.move(self.width, self.height, "x", self.enemy_move_speed)
|
||||||
reward -= 0.1
|
elif self.enemy_formation.center_x > self.player.center_x:
|
||||||
|
self.enemy_formation.move(self.width, self.height, "x", -self.enemy_move_speed)
|
||||||
|
|
||||||
self.steps_since_direction_change = 0
|
if random.random() < 0.02:
|
||||||
|
if random.random() < 0.5:
|
||||||
|
self.enemy_formation.move(self.width, self.height, "y", -self.enemy_move_speed)
|
||||||
else:
|
else:
|
||||||
self.steps_since_direction_change += 1
|
self.enemy_formation.move(self.width, self.height, "y", self.enemy_move_speed)
|
||||||
self.last_direction = current_action_dir
|
|
||||||
|
|
||||||
if enemy_x != 2.0:
|
bullets_to_remove = []
|
||||||
if abs(enemy_x) < 0.03:
|
|
||||||
reward += 0.1
|
|
||||||
elif abs(enemy_x) < 0.08:
|
|
||||||
reward += 0.05
|
|
||||||
|
|
||||||
for b in list(self.bullets):
|
for b in self.bullets:
|
||||||
b.center_y += b.direction_y * BULLET_SPEED
|
b.center_y += b.direction_y * BULLET_SPEED
|
||||||
if b.center_y > self.height or b.center_y < 0:
|
|
||||||
try:
|
|
||||||
self.bullets.remove(b)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
for b in list(self.bullets):
|
if b.center_y > self.height or b.center_y < 0:
|
||||||
|
bullets_to_remove.append(b)
|
||||||
|
continue
|
||||||
|
|
||||||
if b.direction_y == 1:
|
if b.direction_y == 1:
|
||||||
for e in list(self.enemies):
|
for e in self.enemy_formation.enemies:
|
||||||
if arcade.check_for_collision(b, e):
|
if arcade.check_for_collision(b, e):
|
||||||
try:
|
self.enemy_formation.remove_enemy(e)
|
||||||
self.enemies.remove(e)
|
bullets_to_remove.append(b)
|
||||||
except ValueError:
|
reward += 10.0
|
||||||
pass
|
self.enemies_killed += 1
|
||||||
try:
|
|
||||||
self.bullets.remove(b)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
reward += 1.0
|
|
||||||
break
|
break
|
||||||
|
|
||||||
for b in list(self.bullets):
|
elif b.direction_y == -1 and self.player_alive:
|
||||||
if b.direction_y == -1:
|
|
||||||
if arcade.check_for_collision(b, self.player):
|
if arcade.check_for_collision(b, self.player):
|
||||||
try:
|
bullets_to_remove.append(b)
|
||||||
|
reward -= 10.0
|
||||||
|
self.player_alive = False
|
||||||
|
|
||||||
|
if self.player_respawns_remaining > 0:
|
||||||
|
self.player_respawns_remaining -= 1
|
||||||
|
self._respawn_player()
|
||||||
|
reward += 2.0
|
||||||
|
else:
|
||||||
|
terminated = True
|
||||||
|
|
||||||
|
for b in bullets_to_remove:
|
||||||
|
if b in self.bullets:
|
||||||
self.bullets.remove(b)
|
self.bullets.remove(b)
|
||||||
except ValueError:
|
|
||||||
pass
|
if self.player_alive:
|
||||||
reward -= 5.0
|
lowest_enemy = self._lowest_enemy()
|
||||||
|
if lowest_enemy and lowest_enemy.center_y <= self.player.center_y:
|
||||||
|
reward -= 10.0
|
||||||
|
self.player_alive = False
|
||||||
|
|
||||||
|
if self.player_respawns_remaining > 0:
|
||||||
|
self.player_respawns_remaining -= 1
|
||||||
|
self._respawn_player()
|
||||||
|
else:
|
||||||
terminated = True
|
terminated = True
|
||||||
|
|
||||||
if not self.enemies:
|
if not self.enemy_formation.enemies:
|
||||||
reward += 10.0
|
reward += 50.0
|
||||||
|
|
||||||
|
if self.enemy_respawns_remaining > 0:
|
||||||
|
self.enemy_respawns_remaining -= 1
|
||||||
|
self._respawn_enemies()
|
||||||
|
reward += 20.0
|
||||||
|
else:
|
||||||
|
reward += 100.0
|
||||||
terminated = True
|
terminated = True
|
||||||
|
|
||||||
if self.enemies and random.random() < 0.05:
|
shooting_prob = 0.05 + (0.05 * (1.0 - len(self.enemy_formation.enemies) / (self.difficulty_settings["enemy_rows"] * self.difficulty_settings["enemy_cols"])))
|
||||||
e = random.choice(self.enemies)
|
if self.enemy_formation.enemies and random.random() < shooting_prob:
|
||||||
b = Bullet(e.center_x, e.center_y, -1)
|
enemy = self.enemy_formation.get_lowest_enemy()
|
||||||
|
if enemy:
|
||||||
|
b = Bullet(enemy.center_x, enemy.center_y, -1)
|
||||||
self.bullets.append(b)
|
self.bullets.append(b)
|
||||||
|
|
||||||
curr_bullet = self._nearest_enemy_bullet()
|
if self.player_alive:
|
||||||
if curr_bullet is not None:
|
edge_threshold = self.width * 0.15
|
||||||
curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width)
|
if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold:
|
||||||
else:
|
reward -= 0.03
|
||||||
curr_bx = 2.0
|
|
||||||
|
|
||||||
if self.prev_bx != 2.0 and curr_bx != 2.0:
|
reward -= 0.005
|
||||||
if abs(curr_bx) > abs(self.prev_bx):
|
|
||||||
reward += 0.02
|
|
||||||
|
|
||||||
reward -= 0.01
|
|
||||||
|
|
||||||
obs = self._obs()
|
obs = self._obs()
|
||||||
self.prev_bx = curr_bx
|
|
||||||
|
|
||||||
return obs, float(reward), bool(terminated), bool(truncated), {}
|
return obs, float(reward), bool(terminated), bool(truncated), {
|
||||||
|
"enemies_killed": self.enemies_killed,
|
||||||
|
"step": self.current_step,
|
||||||
|
"player_respawns_remaining": self.player_respawns_remaining,
|
||||||
|
"enemy_respawns_remaining": self.enemy_respawns_remaining
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user