mirror of
https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Add 10 million timestep model, improve README, add difficulty/mode selector and score, give model training instant graphs and multiple envs for faster training, improve plotting, improve the RL model by including multiple players and a better reward system, use EnemyFormation instead of individual Enemy sprites
This commit is contained in:
136
game/sprites.py
136
game/sprites.py
@@ -1,10 +1,8 @@
|
||||
import arcade, time
|
||||
import arcade, time, random, numpy as np
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED, ENEMY_COLS, ENEMY_ROWS
|
||||
from utils.constants import PLAYER_SPEED, BULLET_SPEED, BULLET_RADIUS, PLAYER_ATTACK_SPEED
|
||||
from utils.preload import player_texture, enemy_texture
|
||||
|
||||
class Bullet(arcade.Sprite):
|
||||
@@ -16,9 +14,100 @@ class Bullet(arcade.Sprite):
|
||||
def update(self):
|
||||
self.center_y += self.direction_y * BULLET_SPEED
|
||||
|
||||
class Enemy(arcade.Sprite):
    """A single enemy sprite drawn with the preloaded enemy texture."""

    def __init__(self, x, y):
        """Create the enemy with its centre placed at (x, y)."""
        placement = {"center_x": x, "center_y": y}
        super().__init__(enemy_texture, **placement)
|
||||
class EnemyFormation():
    """Rectangular grid of enemy sprites that is built and moved as one unit.

    ``grid[row][col]`` holds the sprite for that cell, or ``None`` once the
    enemy has been removed. Column ``col`` is placed ``col * SPACING`` pixels
    to the right of ``start_x`` and row ``row`` is placed ``row * SPACING``
    pixels below ``start_y``, so a higher row index means a lower position.
    """

    # Distance in pixels between neighbouring sprite centres, on both axes.
    SPACING = 100

    def __init__(self, start_x, start_y, spritelist: "arcade.SpriteList | None", rows, cols):
        """Store the layout parameters and build the initial formation.

        Args:
            start_x: x coordinate of the centre of the first column.
            start_y: y coordinate of the centre of the first row.
            spritelist: optional sprite list that every created enemy is also
                appended to; pass ``None`` to keep the enemies in the grid only.
            rows: number of rows in the grid.
            cols: number of columns in the grid.
        """
        self.start_x = start_x
        self.start_y = start_y
        self.rows = rows
        self.cols = cols
        self.spritelist = spritelist
        self.grid = []

        self.create_formation()

    def create_formation(self, start_x=None, start_y=None):
        """(Re)build the full grid of enemies, optionally at a new origin.

        Args:
            start_x: new x origin; ``None`` keeps the current one.
            start_y: new y origin; ``None`` keeps the current one.

        Sprites still referenced by the previous grid are not removed from
        ``spritelist`` here.
        """
        # Compare against None so that an origin of 0 is honoured.
        if start_x is not None:
            self.start_x = start_x
        if start_y is not None:
            self.start_y = start_y

        self.grid = [[] for _ in range(self.rows)]

        for row in range(self.rows):
            for col in range(self.cols):
                enemy_sprite = arcade.Sprite(
                    enemy_texture,
                    center_x=self.start_x + col * self.SPACING,
                    center_y=self.start_y - row * self.SPACING,
                )

                # An empty sprite list is falsy, so test for None explicitly;
                # otherwise nothing would ever be added to a fresh list.
                if self.spritelist is not None:
                    self.spritelist.append(enemy_sprite)

                self.grid[row].append(enemy_sprite)

    def remove_enemy(self, enemy):
        """Clear ``enemy`` from the grid and from the sprite list, if any.

        Does nothing when a sprite list is attached and ``enemy`` is not in
        it, or when ``enemy`` is not found in the grid.
        """
        if self.spritelist is not None and enemy not in self.spritelist:
            return

        for cells in self.grid:
            for col, cell in enumerate(cells):
                if cell is enemy:
                    cells[col] = None
                    if self.spritelist is not None:
                        self.spritelist.remove(enemy)
                    return

    def get_lowest_enemy(self):
        """Return the bottom-most remaining enemy of a randomly chosen column.

        Columns without any remaining enemy are not candidates. Returns
        ``None`` when the formation is empty.
        """
        candidates = []

        for col in range(self.cols):
            # Walk upwards from the last row until an occupied cell is found.
            for row in range(self.rows - 1, -1, -1):
                if self.grid[row][col] is not None:
                    candidates.append(self.grid[row][col])
                    break

        if not candidates:
            return None

        return random.choice(candidates)

    def move(self, width, height, direction_type, value):
        """Shift the whole formation by ``value`` along one axis.

        Args:
            width: width of the playfield, used as the right-hand limit.
            height: height of the playfield, used as the upper limit.
            direction_type: ``"x"`` moves horizontally; any other value moves
                vertically.
            value: signed offset to apply.

        The move is all-or-nothing: if any enemy would end up overlapping an
        edge of the playfield, nothing is moved. The formation origin is
        shifted exactly once per successful move so it stays in step with
        the sprites.
        """
        enemies = self.enemies
        if not enemies:
            return

        if direction_type == "x":
            blocked = any(
                enemy.center_x + value + enemy.width / 2 > width
                or enemy.center_x + value < enemy.width / 2
                for enemy in enemies
            )
            if blocked:
                return

            for enemy in enemies:
                enemy.center_x += value
            self.start_x += value
        else:
            blocked = any(
                enemy.center_y + value + enemy.height / 2 > height
                or enemy.center_y + value < enemy.height / 2
                for enemy in enemies
            )
            if blocked:
                return

            for enemy in enemies:
                enemy.center_y += value
            self.start_y += value

    @property
    def center_x(self):
        """Approximate horizontal centre of the formation.

        NOTE(review): the sprite centres span ``start_x`` to
        ``start_x + (cols - 1) * 100``, so this value lies 50 px to the right
        of their midpoint. Kept unchanged because callers may rely on it.
        """
        return self.start_x + (self.cols / 2) * 100

    @property
    def enemies(self):
        """List of all remaining enemy sprites, in row-major order."""
        return [cell for cells in self.grid for cell in cells if cell is not None]
|
||||
|
||||
class Player(arcade.Sprite): # Not actually the player
|
||||
def __init__(self, x, y):
|
||||
@@ -26,21 +115,19 @@ class Player(arcade.Sprite): # Not actually the player
|
||||
|
||||
self.last_target_change = time.perf_counter()
|
||||
self.last_shoot = time.perf_counter()
|
||||
self.target = None
|
||||
self.shoot = False
|
||||
|
||||
self.player_speed = 0
|
||||
|
||||
def update(self, model: PPO, enemies, bullets, width, height):
|
||||
if enemies:
|
||||
nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x))
|
||||
def update(self, model: PPO, enemy_formation, bullets, width, height, player_respawns_norm, enemy_respawns_norm):
|
||||
if enemy_formation.enemies:
|
||||
nearest_enemy = min(enemy_formation.enemies, key=lambda e: abs(e.center_x - self.center_x))
|
||||
enemy_x = (nearest_enemy.center_x - self.center_x) / width
|
||||
enemy_y = (nearest_enemy.center_y - self.center_y) / height
|
||||
else:
|
||||
enemy_x = 2
|
||||
enemy_y = 2
|
||||
|
||||
enemy_count = len(enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS))
|
||||
enemy_count = len(enemy_formation.enemies) / float(max(1, enemy_formation.rows * enemy_formation.cols))
|
||||
player_x_norm = self.center_x / width
|
||||
|
||||
curr_bullet = min(bullets, key=lambda b: abs(b.center_x - self.center_x) + abs(b.center_y - self.center_y)) if bullets else None
|
||||
@@ -51,18 +138,31 @@ class Player(arcade.Sprite): # Not actually the player
|
||||
curr_bx = 2.0
|
||||
curr_by = 2.0
|
||||
|
||||
lowest = max(enemies, key=lambda e: e.center_y) if enemies else None
|
||||
lowest = max(enemy_formation.enemies, key=lambda e: e.center_y) if enemy_formation.enemies else None
|
||||
if lowest is not None:
|
||||
lowest_dy = (lowest.center_y - self.center_y) / float(height)
|
||||
else:
|
||||
lowest_dy = 2.0
|
||||
|
||||
enemy_dispersion = 0.0
|
||||
if enemies:
|
||||
xs = np.array([e.center_x for e in enemies], dtype=np.float32)
|
||||
if enemy_formation.enemies:
|
||||
xs = np.array([e.center_x for e in enemy_formation.enemies], dtype=np.float32)
|
||||
enemy_dispersion = float(xs.std()) / float(width)
|
||||
|
||||
obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, curr_bx, curr_by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32)
|
||||
obs = np.array([
|
||||
player_x_norm,
|
||||
enemy_x,
|
||||
enemy_y,
|
||||
lowest_dy,
|
||||
curr_bx,
|
||||
curr_by,
|
||||
self.player_speed,
|
||||
enemy_count,
|
||||
enemy_dispersion,
|
||||
time.perf_counter() - self.last_shoot >= PLAYER_ATTACK_SPEED,
|
||||
player_respawns_norm,
|
||||
enemy_respawns_norm
|
||||
], dtype=np.float32)
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
||||
self.prev_x = self.center_x
|
||||
@@ -71,6 +171,8 @@ class Player(arcade.Sprite): # Not actually the player
|
||||
elif action == 1:
|
||||
self.center_x += PLAYER_SPEED
|
||||
elif action == 2:
|
||||
pass
|
||||
elif action == 3:
|
||||
t = time.perf_counter()
|
||||
if t - self.last_shoot >= PLAYER_ATTACK_SPEED:
|
||||
self.last_shoot = t
|
||||
|
||||
Reference in New Issue
Block a user