Mirror of https://github.com/csd4ni3l/fleet-commander.git (synced 2026-01-01 04:23:47 +01:00)
Add model training with graphs and current stats, improve model with better rewarding system
This commit is contained in:
utils/rl.py (new file, 227 lines added)
@@ -0,0 +1,227 @@
import gymnasium as gym
import numpy as np
import arcade
import time
import random

from game.sprites import Enemy, Player, Bullet
from utils.constants import PLAYER_SPEED, BULLET_SPEED, PLAYER_ATTACK_SPEED, ENEMY_ROWS, ENEMY_COLS


class SpaceInvadersEnv(gym.Env):
    def __init__(self, width=800, height=600):
        self.width = width
        self.height = height

        # Actions: 0 = move left, 1 = move right, 2 = shoot.
        self.action_space = gym.spaces.Discrete(3)
        # 9-dimensional observation; 2.0 is used as a "not present" sentinel.
        self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(9,), dtype=np.float32)

        self.enemies = []
        self.bullets = []
        self.dir_history = []
        self.last_shot = 0.0
        self.player = None
        self.prev_x = 0.0
        self.player_speed = 0.0
        self.prev_bx = 2.0
        self.steps_since_direction_change = 0
        self.last_direction = 0
        self.max_steps = 1000
        self.current_step = 0

    def reset(self, seed=None, options=None):
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        self.enemies = []
        self.bullets = []
        self.dir_history = []
        self.player = Player(self.width / 2 + random.randint(int(-self.width / 3), int(self.width / 3)), 100)
        self.prev_x = self.player.center_x
        self.player_speed = 0.0
        self.prev_bx = 2.0
        self.steps_since_direction_change = 0
        self.last_direction = 0
        self.current_step = 0

        start_x = self.width * 0.15
        start_y = self.height * 0.9

        for r in range(ENEMY_ROWS):
            for c in range(ENEMY_COLS):
                e = Enemy(start_x + c * 100, start_y - r * 100)
                self.enemies.append(e)

        self.last_shot = time.perf_counter()
        return self._obs(), {}

    def _nearest_enemy(self):
        if not self.enemies:
            return None
        return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x))

    def _lowest_enemy(self):
        if not self.enemies:
            return None
        return max(self.enemies, key=lambda e: e.center_y)

    def _nearest_enemy_bullet(self):
        enemy_bullets = [b for b in self.bullets if b.direction_y == -1]
        if not enemy_bullets:
            return None
        return min(enemy_bullets, key=lambda b: abs(b.center_x - self.player.center_x) + abs(b.center_y - self.player.center_y))

    def _obs(self):
        # Observation layout: [player_x_norm, enemy_dx, enemy_dy, lowest_dy,
        # bullet_dx, bullet_dy, player_speed, enemy_count, enemy_dispersion].
        if self.enemies:
            nearest = self._nearest_enemy()
            enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
            enemy_y = (nearest.center_y - self.player.center_y) / float(self.height)
        else:
            enemy_x = 2.0
            enemy_y = 2.0

        lowest = self._lowest_enemy()

        if lowest is not None:
            lowest_dy = (lowest.center_y - self.player.center_y) / float(self.height)
        else:
            lowest_dy = 2.0

        nb = self._nearest_enemy_bullet()
        if nb is not None:
            bx = (nb.center_x - self.player.center_x) / float(self.width)
            by = (nb.center_y - self.player.center_y) / float(self.height)
        else:
            bx = 2.0
            by = 2.0

        enemy_count = len(self.enemies) / float(max(1, ENEMY_ROWS * ENEMY_COLS))
        player_x_norm = self.player.center_x / float(self.width)
        enemy_dispersion = 0.0

        if self.enemies:
            xs = np.array([e.center_x for e in self.enemies], dtype=np.float32)
            enemy_dispersion = float(xs.std()) / float(self.width)

        obs = np.array([player_x_norm, enemy_x, enemy_y, lowest_dy, bx, by, self.player_speed, enemy_count, enemy_dispersion], dtype=np.float32)
        return obs

    def step(self, action):
        reward = 0.0
        terminated = False
        truncated = False

        self.current_step += 1
        if self.current_step >= self.max_steps:
            truncated = True

        nearest = self._nearest_enemy()
        if nearest is not None:
            enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
        else:
            enemy_x = 2.0

        prev_x = self.player.center_x
        current_action_dir = 0

        if action == 0:
            self.player.center_x -= PLAYER_SPEED
            current_action_dir = -1
        elif action == 1:
            self.player.center_x += PLAYER_SPEED
            current_action_dir = 1
        elif action == 2:
            t = time.perf_counter()
            if t - self.last_shot >= PLAYER_ATTACK_SPEED:
                self.last_shot = t

                b = Bullet(self.player.center_x, self.player.center_y, 1)
                self.bullets.append(b)

                # Bonus for firing while roughly aligned with the nearest enemy.
                if enemy_x != 2.0 and abs(enemy_x) < 0.04:
                    reward += 0.3
                elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
                    reward += 0.1

        # Keep the player inside the screen.
        if self.player.center_x > self.width:
            self.player.center_x = self.width
        elif self.player.center_x < 0:
            self.player.center_x = 0

        self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)

        # Penalize rapid left/right direction flips (jittering).
        if current_action_dir != 0:
            if self.last_direction != 0 and current_action_dir != self.last_direction:
                if self.steps_since_direction_change < 3:
                    reward -= 0.1
                self.steps_since_direction_change = 0
            else:
                self.steps_since_direction_change += 1
            self.last_direction = current_action_dir

        # Small shaping reward for staying aligned with the nearest enemy.
        if enemy_x != 2.0:
            if abs(enemy_x) < 0.03:
                reward += 0.1
            elif abs(enemy_x) < 0.08:
                reward += 0.05

        # Move bullets and drop the ones that left the screen.
        for b in list(self.bullets):
            b.center_y += b.direction_y * BULLET_SPEED
            if b.center_y > self.height or b.center_y < 0:
                try:
                    self.bullets.remove(b)
                except ValueError:
                    pass

        # Player bullets vs. enemies.
        for b in list(self.bullets):
            if b.direction_y == 1:
                for e in list(self.enemies):
                    if arcade.check_for_collision(b, e):
                        try:
                            self.enemies.remove(e)
                        except ValueError:
                            pass
                        try:
                            self.bullets.remove(b)
                        except ValueError:
                            pass
                        reward += 1.0
                        break

        # Enemy bullets vs. the player.
        for b in list(self.bullets):
            if b.direction_y == -1:
                if arcade.check_for_collision(b, self.player):
                    try:
                        self.bullets.remove(b)
                    except ValueError:
                        pass
                    reward -= 5.0
                    terminated = True

        if not self.enemies:
            reward += 10.0
            terminated = True

        # Random enemy fire.
        if self.enemies and random.random() < 0.05:
            e = random.choice(self.enemies)
            b = Bullet(e.center_x, e.center_y, -1)
            self.bullets.append(b)

        # Small bonus for increasing horizontal distance to the nearest enemy bullet.
        curr_bullet = self._nearest_enemy_bullet()
        if curr_bullet is not None:
            curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width)
        else:
            curr_bx = 2.0

        if self.prev_bx != 2.0 and curr_bx != 2.0:
            if abs(curr_bx) > abs(self.prev_bx):
                reward += 0.02

        # Per-step time penalty.
        reward -= 0.01

        obs = self._obs()
        self.prev_bx = curr_bx

        return obs, float(reward), bool(terminated), bool(truncated), {}
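Note: only the environment is included in this excerpt; the training script with graphs and running stats mentioned in the commit message is not shown here. As a minimal sketch (not part of the commit), the environment can be sanity-checked with a random policy using the standard Gymnasium loop. The import path utils.rl matches the file added above; everything else is illustrative.

    from utils.rl import SpaceInvadersEnv

    env = SpaceInvadersEnv()
    obs, info = env.reset(seed=0)
    episode_return = 0.0

    for _ in range(1000):
        action = env.action_space.sample()  # random action: 0 = left, 1 = right, 2 = shoot
        obs, reward, terminated, truncated, info = env.step(action)
        episode_return += reward
        if terminated or truncated:
            print(f"episode return: {episode_return:.2f}")
            obs, info = env.reset()
            episode_return = 0.0

A real agent would replace env.action_space.sample() with a learned policy, for example a stable-baselines3 PPO model trained against this env; which trainer the project actually uses is not visible in this diff.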