mirror of
https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Add 10 million timestep model, improve README, add difficulty/mode selector and score, make model training show instant graphs and use multiple envs for faster training, improve plotting, improve RL model by including multiple players and a better reward system, use EnemyFormation instead of single Enemy sprites
This commit is contained in:
118
game/play.py
118
game/play.py
@@ -1,16 +1,17 @@
|
||||
import arcade, arcade.gui, random, time
|
||||
|
||||
from utils.constants import button_style, ENEMY_ROWS, ENEMY_COLS, PLAYER_ATTACK_SPEED
|
||||
from utils.constants import button_style, ENEMY_ATTACK_SPEED, ENEMY_SPEED
|
||||
from utils.preload import button_texture, button_hovered_texture
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from game.sprites import Enemy, Player, Bullet
|
||||
from game.sprites import EnemyFormation, Player, Bullet
|
||||
|
||||
class Game(arcade.gui.UIView):
|
||||
def __init__(self, pypresence_client):
|
||||
def __init__(self, pypresence_client, settings):
|
||||
super().__init__()
|
||||
|
||||
self.settings = settings
|
||||
self.pypresence_client = pypresence_client
|
||||
self.pypresence_client.update(state="Invading Space")
|
||||
|
||||
@@ -18,42 +19,53 @@ class Game(arcade.gui.UIView):
|
||||
|
||||
self.spritelist = arcade.SpriteList()
|
||||
|
||||
self.player = Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100) # not actually player
|
||||
self.spritelist.append(self.player)
|
||||
|
||||
self.last_player_shoot = time.perf_counter() # not actually player
|
||||
self.players = []
|
||||
for _ in range(settings["player_count"]):
|
||||
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
|
||||
self.spritelist.append(self.players[-1])
|
||||
|
||||
self.model = PPO.load("invader_agent.zip")
|
||||
|
||||
self.enemies: list[Enemy] = []
|
||||
self.enemy_formation = EnemyFormation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9, self.spritelist, settings["enemy_rows"], settings["enemy_cols"])
|
||||
self.player_bullets: list[Bullet] = []
|
||||
self.enemy_bullets: list[Bullet] = []
|
||||
|
||||
self.summon_enemies()
|
||||
self.player_respawns = settings["player_respawns"]
|
||||
self.enemy_respawns = settings["enemy_respawns"]
|
||||
|
||||
self.score = 0
|
||||
|
||||
self.last_enemy_shoot = time.perf_counter()
|
||||
|
||||
self.game_over = False
|
||||
|
||||
def on_show_view(self):
|
||||
super().on_show_view()
|
||||
|
||||
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
||||
self.back_button.on_click = lambda event: self.main_exit()
|
||||
self.score_label = self.anchor.add(arcade.gui.UILabel("Score: 0", font_size=24), anchor_x="center", anchor_y="top")
|
||||
|
||||
def main_exit(self):
|
||||
from menus.main import Main
|
||||
self.window.show_view(Main(self.pypresence_client))
|
||||
|
||||
def summon_enemies(self):
|
||||
enemy_start_x = self.window.width * 0.15
|
||||
enemy_start_y = self.window.height * 0.9
|
||||
|
||||
for row in range(ENEMY_ROWS):
|
||||
for col in range(ENEMY_COLS):
|
||||
enemy_sprite = Enemy(enemy_start_x + col * 100, enemy_start_y - row * 100)
|
||||
self.spritelist.append(enemy_sprite)
|
||||
self.enemies.append(enemy_sprite)
|
||||
|
||||
def on_update(self, delta_time):
|
||||
for enemy in self.enemies:
|
||||
enemy.update()
|
||||
if self.game_over:
|
||||
return
|
||||
|
||||
if self.window.keyboard[arcade.key.LEFT] or self.window.keyboard[arcade.key.A]:
|
||||
self.enemy_formation.move(self.window.width, self.window.height, "x", -ENEMY_SPEED)
|
||||
if self.window.keyboard[arcade.key.RIGHT] or self.window.keyboard[arcade.key.D]:
|
||||
self.enemy_formation.move(self.window.width, self.window.height, "x", ENEMY_SPEED)
|
||||
if self.window.keyboard[arcade.key.DOWN] or self.window.keyboard[arcade.key.S]:
|
||||
self.enemy_formation.move(self.window.width, self.window.height, "y", -ENEMY_SPEED)
|
||||
if self.window.keyboard[arcade.key.UP] or self.window.keyboard[arcade.key.W]:
|
||||
self.enemy_formation.move(self.window.width, self.window.height, "y", ENEMY_SPEED)
|
||||
if self.enemy_formation.enemies and self.window.keyboard[arcade.key.SPACE] and time.perf_counter() - self.last_enemy_shoot >= ENEMY_ATTACK_SPEED:
|
||||
self.last_enemy_shoot = time.perf_counter()
|
||||
enemy = self.enemy_formation.get_lowest_enemy()
|
||||
self.shoot(enemy.center_x, enemy.center_y, -1)
|
||||
|
||||
bullets_to_remove = []
|
||||
|
||||
@@ -62,18 +74,21 @@ class Game(arcade.gui.UIView):
|
||||
|
||||
bullet_hit = False
|
||||
if bullet.direction_y == 1:
|
||||
for enemy in self.enemies:
|
||||
for enemy in self.enemy_formation.enemies:
|
||||
if bullet.rect.intersection(enemy.rect):
|
||||
self.spritelist.remove(enemy)
|
||||
self.enemies.remove(enemy)
|
||||
self.player.target = None
|
||||
self.enemy_formation.remove_enemy(enemy)
|
||||
bullets_to_remove.append(bullet)
|
||||
bullet_hit = True
|
||||
break
|
||||
else:
|
||||
if bullet.rect.intersection(self.player.rect):
|
||||
bullets_to_remove.append(bullet)
|
||||
bullet_hit = True
|
||||
for player in self.players:
|
||||
if bullet.rect.intersection(player.rect):
|
||||
self.spritelist.remove(player)
|
||||
self.players.remove(player)
|
||||
bullets_to_remove.append(bullet)
|
||||
bullet_hit = True
|
||||
self.score += 75
|
||||
break
|
||||
|
||||
if not bullet_hit and bullet.center_y > self.window.height or bullet.center_y < 0:
|
||||
bullets_to_remove.append(bullet)
|
||||
@@ -86,20 +101,40 @@ class Game(arcade.gui.UIView):
|
||||
elif bullet_to_remove in self.player_bullets:
|
||||
self.player_bullets.remove(bullet_to_remove)
|
||||
|
||||
self.player.update(self.model, self.enemies, self.enemy_bullets, self.window.width, self.window.height) # not actually player
|
||||
for player in self.players:
|
||||
player.update(self.model, self.enemy_formation, self.enemy_bullets, self.window.width, self.window.height, self.player_respawns / self.settings["player_respawns"], self.enemy_respawns / self.settings["enemy_respawns"]) # not actually player
|
||||
|
||||
if self.player.center_x > self.window.width:
|
||||
self.player.center_x = self.window.width
|
||||
elif self.player.center_x < 0:
|
||||
self.player.center_x = 0
|
||||
if player.center_x > self.window.width:
|
||||
player.center_x = self.window.width
|
||||
elif player.center_x < 0:
|
||||
player.center_x = 0
|
||||
|
||||
if self.player.shoot:
|
||||
self.player.shoot = False
|
||||
self.shoot(self.player.center_x, self.player.center_y, 1)
|
||||
if player.shoot:
|
||||
player.shoot = False
|
||||
self.shoot(player.center_x, player.center_y, 1)
|
||||
|
||||
if time.perf_counter() - self.last_player_shoot >= PLAYER_ATTACK_SPEED:
|
||||
self.last_player_shoot = time.perf_counter()
|
||||
self.shoot(self.player.center_x, self.player.center_y, 1)
|
||||
if len(self.players) == 0:
|
||||
if self.player_respawns > 0:
|
||||
self.player_respawns -= 1
|
||||
for _ in range(self.settings["player_count"]):
|
||||
self.players.append(Player(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), 100)) # not actually player
|
||||
self.spritelist.append(self.players[-1])
|
||||
self.score += 300
|
||||
else:
|
||||
self.game_over = True
|
||||
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You (The enemies) won!", font_size=48), anchor_x="center", anchor_y="center")
|
||||
|
||||
elif len(self.enemy_formation.enemies) == 0:
|
||||
if self.enemy_respawns > 0:
|
||||
self.enemy_respawns -= 1
|
||||
self.enemy_formation.create_formation(self.window.width / 2 + random.randint(int(-self.window.width / 3), int(self.window.width / 3)), self.window.height * 0.9)
|
||||
else:
|
||||
self.game_over = True
|
||||
self.game_over_label = self.anchor.add(arcade.gui.UILabel("You lost! The Players win!", font_size=48), anchor_x="center", anchor_y="center")
|
||||
|
||||
self.score += 5 * delta_time
|
||||
|
||||
self.score_label.text = f"Score: {int(self.score)}"
|
||||
|
||||
def shoot(self, x, y, direction_y):
|
||||
bullet = Bullet(x, y, direction_y)
|
||||
@@ -112,11 +147,6 @@ class Game(arcade.gui.UIView):
|
||||
|
||||
bullets.append(bullet)
|
||||
|
||||
def on_key_press(self, symbol, modifiers):
|
||||
if symbol == arcade.key.SPACE:
|
||||
enemy = random.choice(self.enemies)
|
||||
self.shoot(enemy.center_x, enemy.center_y, -1)
|
||||
|
||||
def on_draw(self):
|
||||
super().on_draw()
|
||||
|
||||
|
||||
136
game/sprites.py
136
game/sprites.py
@@ -1,10 +1,8 @@
|
||||
import arcade, time
|
||||
import arcade, time, random, numpy as np
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED, ENEMY_COLS, ENEMY_ROWS
|
||||
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED
|
||||
from utils.preload import player_texture, enemy_texture
|
||||
|
||||
class Bullet(arcade.Sprite):
|
||||
@@ -16,9 +14,100 @@ class Bullet(arcade.Sprite):
|
||||
def update(self):
|
||||
self.center_y += self.direction_y * BULLET_SPEED
|
||||
|
||||
class Enemy(arcade.Sprite):
|
||||
def __init__(self, x, y):
|
||||
super().__init__(enemy_texture, center_x=x, center_y=y)
|
||||
class EnemyFormation():
    """A rows x cols grid of enemy sprites that moves as a single unit.

    ``grid[row][col]`` holds the sprite occupying that cell, or ``None`` once
    that enemy has been removed. Row 0 is the top row; each following row is
    placed ``SPACING`` pixels lower. ``start_x`` / ``start_y`` track the
    position of the top-left cell.
    """

    # Distance in pixels between neighbouring grid cells.
    SPACING = 100

    def __init__(self, start_x, start_y, spritelist: "arcade.SpriteList | None", rows, cols):
        """Store the layout parameters and build the initial formation.

        ``spritelist`` may be ``None``; in that case the sprites are only
        tracked in the grid and never registered anywhere for drawing.
        """
        self.start_x = start_x
        self.start_y = start_y
        self.rows = rows
        self.cols = cols
        self.spritelist = spritelist
        self.grid = [[] for _ in range(rows)]

        self.create_formation()

    def create_formation(self, start_x=None, start_y=None):
        """(Re)build the full grid, optionally at a new top-left position.

        Any enemy still alive in the old grid is removed from the sprite list
        first so it does not stay behind as an untracked sprite.
        """
        # Compare against None so that a coordinate of 0 is honoured.
        if start_x is not None:
            self.start_x = start_x
        if start_y is not None:
            self.start_y = start_y

        if self.spritelist is not None:
            for leftover in self.enemies:
                if leftover in self.spritelist:
                    self.spritelist.remove(leftover)

        self.grid = [[] for _ in range(self.rows)]

        for row in range(self.rows):
            for col in range(self.cols):
                enemy_sprite = arcade.Sprite(
                    enemy_texture,
                    center_x=self.start_x + col * self.SPACING,
                    center_y=self.start_y - row * self.SPACING,
                )

                # "is not None" rather than truthiness: a sprite list that is
                # currently empty must still receive the new sprites.
                if self.spritelist is not None:
                    self.spritelist.append(enemy_sprite)

                self.grid[row].append(enemy_sprite)

    def remove_enemy(self, enemy):
        """Clear ``enemy`` from the grid and from the sprite list, if present.

        Does nothing when a sprite list is in use and the enemy is no longer
        in it, or when the enemy is not found in the grid.
        """
        if self.spritelist is not None and enemy not in self.spritelist:
            return

        for row in self.grid:
            for col, occupant in enumerate(row):
                if occupant is enemy:
                    row[col] = None
                    if self.spritelist is not None:
                        self.spritelist.remove(enemy)
                    return

    def get_lowest_enemy(self):
        """Return the bottom-most enemy of a randomly chosen non-empty column.

        Returns ``None`` when no enemies are left.
        """
        candidates = []

        for col in range(self.cols):
            # Walk upwards from the bottom row to the first occupied cell.
            for row in range(self.rows - 1, -1, -1):
                occupant = self.grid[row][col]
                if occupant is not None:
                    candidates.append(occupant)
                    break

        if not candidates:
            return None

        return random.choice(candidates)

    def move(self, width, height, direction_type, value):
        """Shift the whole formation by ``value`` along one axis.

        ``direction_type`` is ``"x"`` for horizontal movement; any other value
        moves vertically. The move is all-or-nothing: if any enemy would leave
        the ``width`` x ``height`` area, nothing is moved. With no enemies
        left the call is a no-op.
        """
        enemies = self.enemies
        if not enemies:
            return

        if direction_type == "x":
            pos_attr, size_attr, anchor_attr, limit = "center_x", "width", "start_x", width
        else:
            pos_attr, size_attr, anchor_attr, limit = "center_y", "height", "start_y", height

        # Validate the target positions first so nothing has to be reverted.
        for enemy in enemies:
            half_size = getattr(enemy, size_attr) / 2
            new_pos = getattr(enemy, pos_attr) + value
            if new_pos + half_size > limit or new_pos < half_size:
                return

        for enemy in enemies:
            setattr(enemy, pos_attr, getattr(enemy, pos_attr) + value)

        # The anchor moves exactly once per move, on the same axis as the
        # sprites, so it stays aligned with the top-left cell.
        setattr(self, anchor_attr, getattr(self, anchor_attr) + value)

    @property
    def center_x(self):
        """Horizontal reference point: ``start_x`` plus half the column count times the spacing."""
        return self.start_x + (self.cols / 2) * self.SPACING

    @property
    def enemies(self):
        """Flat list of all enemies still in the grid, in row-major order."""
        return [cell for row in self.grid for cell in row if cell is not None]
|
||||
|
||||
class Player(arcade.Sprite): # Not actually the player
|
||||
def __init__(self, x, y):
|
||||
@@ -26,21 +115,19 @@ class Player(arcade.Sprite): # Not actually the player
|
||||
|
||||
self.last_target_change = time.perf_counter()
|
||||
self.last_shoot = time.perf_counter()
|
||||
self.target = None
|
||||
self.shoot = False
|
||||
|
||||
self.player_speed = 0
|
||||
|
||||
def update(self, model: PPO, enemies, bullets, width, height):
|
||||
if enemies:
|
||||
nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x))
|
||||
def update(self, model: PPO, enemy_formation, bullets, width, height, player_respawns_norm, enemy_respawns_norm):
|
||||
if enemy_formation.enemies:
|
||||
nearest_enemy = min(enemy_formation.enemies, key=lambda e: abs(e.center_x - self.center_x))
|
||||
enemy_x = (nearest_enemy.center_x - self.center_x) / width
|
||||
enemy_y = (nearest_enemy.center_y - self.center_y) / height
|
||||
else:
|
||||
enemy_x = 2
|
||||
enemy_y = 2
|
||||
|
||||
enemy_count = len(enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS))
|
||||
enemy_count = len(enemy_formation.enemies) / float(max(1, enemy_formation.rows * enemy_formation.cols))
|
||||
player_x_norm = self.center_x / width
|
||||
|
||||
curr_bullet = min(bullets, key=lambda b: abs(b.center_x - self.center_x) + abs(b.center_y - self.center_y)) if bullets else None
|
||||
@@ -51,18 +138,31 @@ class Player(arcade.Sprite): # Not actually the player
|
||||
curr_bx = 2.0
|
||||
curr_by = 2.0
|
||||
|
||||
lowest = max(enemies, key=lambda e: e.center_y) if enemies else None
|
||||
lowest = max(enemy_formation.enemies, key=lambda e: e.center_y) if enemy_formation.enemies else None
|
||||
if lowest is not None:
|
||||
lowest_dy = (lowest.center_y - self.center_y) / float(height)
|
||||
else:
|
||||
lowest_dy = 2.0
|
||||
|
||||
enemy_dispersion = 0.0
|
||||
if enemies:
|
||||
xs = np.array([e.center_x for e in enemies], dtype=np.float32)
|
||||
if enemy_formation.enemies:
|
||||
xs = np.array([e.center_x for e in enemy_formation.enemies], dtype=np.float32)
|
||||
enemy_dispersion = float(xs.std()) / float(width)
|
||||
|
||||
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, curr_bx, curr_by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32)
|
||||
obs = np.array([
|
||||
player_x_norm,
|
||||
enemy_x,
|
||||
enemy_y,
|
||||
lowest_dy,
|
||||
curr_bx,
|
||||
curr_by,
|
||||
self.player_speed,
|
||||
enemy_count,
|
||||
enemy_dispersion,
|
||||
time.perf_counter() - self.last_shoot >= PLAYER_ATTACK_SPEED,
|
||||
player_respawns_norm,
|
||||
enemy_respawns_norm
|
||||
], dtype=np.float32)
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
||||
self.prev_x = self.center_x
|
||||
@@ -71,6 +171,8 @@ class Player(arcade.Sprite): # Not actually the player
|
||||
elif action == 1:
|
||||
self.center_x += PLAYER_SPEED
|
||||
elif action == 2:
|
||||
pass
|
||||
elif action == 3:
|
||||
t = time.perf_counter()
|
||||
if t - self.last_shoot >= PLAYER_ATTACK_SPEED:
|
||||
self.last_shoot = t
|
||||
|
||||
Reference in New Issue
Block a user