# Mirror of https://github.com/csd4ni3l/fleet-commander.git
# Synced 2026-01-01 04:23:47 +01:00 (276 lines, 10 KiB, Python).
import random

import arcade
import gymnasium as gym
import numpy as np

from game.sprites import EnemyFormation, Player, Bullet
from utils.constants import PLAYER_SPEED, BULLET_SPEED, ENEMY_SPEED, DIFFICULTY_LEVELS


class SpaceInvadersEnv(gym.Env):
    """Gymnasium environment for a simplified, single-agent Space Invaders.

    The player sprite is spawned on the row y=100 and can move horizontally
    and shoot (its bullets travel with ``direction_y == 1``).  An
    ``EnemyFormation`` starts at 90% of the playfield height, steps toward
    the player's x position and fires back (``direction_y == -1`` bullets).

    Actions (``Discrete(4)``):
        0: move left by ``PLAYER_SPEED``
        1: move right by ``PLAYER_SPEED``
        2: no-op
        3: shoot (penalised if the weapon is still on cooldown)

    Observations are 12 ``float32`` features with declared bounds
    ``[-2, 2]``; see ``_obs`` for the layout.  A value of ``2.0`` is used as
    a sentinel for "no such object".

    An episode terminates when the player dies with no respawns left or when
    the last enemy wave is cleared, and is truncated once ``max_steps`` steps
    have been taken.
    """

    def __init__(self, width=800, height=600, difficulty="Hard"):
        """Create the environment.

        Args:
            width: Playfield width, in sprite coordinate units.
            height: Playfield height, in sprite coordinate units.
            difficulty: Key into ``DIFFICULTY_LEVELS``; selects the formation
                size (``enemy_rows`` x ``enemy_cols``) and the number of
                player and enemy respawns.

        Raises:
            ValueError: If ``difficulty`` is not a key of ``DIFFICULTY_LEVELS``.
        """
        if difficulty not in DIFFICULTY_LEVELS:
            raise ValueError(f"Unknown difficulty: {difficulty}. Available: {list(DIFFICULTY_LEVELS.keys())}")

        self.width = width
        self.height = height

        self.action_space = gym.spaces.Discrete(4)
        self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)

        self.difficulty_settings = DIFFICULTY_LEVELS[difficulty]

        # Per-episode state; populated by reset().
        self.bullets = []
        self.player = None
        self.enemy_formation = None
        self.player_speed = 0.0
        self.current_step = 0
        self.enemies_killed = 0
        self.player_respawns_remaining = 0
        self.enemy_respawns_remaining = 0
        self.player_alive = True
        self.current_cooldown = 0

        # Fixed configuration.
        self.max_steps = 2000
        self.enemy_move_speed = ENEMY_SPEED
        self.player_respawns = self.difficulty_settings["player_respawns"]
        self.enemy_respawns = self.difficulty_settings["enemy_respawns"]
        self.player_attack_cooldown_steps = 5

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    def _max_enemies(self):
        """Return the full formation size, clamped to at least 1 for safe division."""
        return max(1, self.difficulty_settings["enemy_rows"] * self.difficulty_settings["enemy_cols"])

    def _formation_origin(self):
        """Return the (x, y) position at which every enemy wave starts."""
        return self.width * 0.15, self.height * 0.9

    def _spawn_player(self):
        """Create a Player at y=100 and a random x within +/- width/3 of the centre."""
        offset = random.randint(int(-self.width / 3), int(self.width / 3))
        return Player(self.width / 2 + offset, 100)

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed.  Seeds Gymnasium's ``self.np_random`` and,
                as before, the global ``random`` and ``np.random`` modules
                (player placement here draws from the global ``random``).
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        self.bullets = []
        self.player = self._spawn_player()
        self.player_speed = 0.0
        self.current_step = 0
        self.enemies_killed = 0
        self.player_respawns_remaining = self.player_respawns
        self.enemy_respawns_remaining = self.enemy_respawns
        self.player_alive = True
        self.current_cooldown = 0

        start_x, start_y = self._formation_origin()
        # NOTE(review): the third EnemyFormation argument is passed as None;
        # its meaning is defined in game.sprites — confirm there.
        self.enemy_formation = EnemyFormation(start_x, start_y, None,
                                              self.difficulty_settings["enemy_rows"],
                                              self.difficulty_settings["enemy_cols"])

        return self._obs(), {}

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def _nearest_enemy(self):
        """Return the enemy horizontally closest to the player, or None if none remain."""
        if not self.enemy_formation.enemies:
            return None
        return min(self.enemy_formation.enemies, key=lambda e: abs(e.center_x - self.player.center_x))

    def _lowest_enemy(self):
        """Return the enemy with the smallest ``center_y``, or None if none remain."""
        if not self.enemy_formation.enemies:
            return None
        return min(self.enemy_formation.enemies, key=lambda e: e.center_y)

    def _nearest_enemy_bullet(self):
        """Return the enemy bullet closest to the player (Manhattan distance), or None."""
        enemy_bullets = [b for b in self.bullets if b.direction_y == -1]
        if not enemy_bullets:
            return None
        return min(enemy_bullets,
                   key=lambda b: abs(b.center_x - self.player.center_x) + abs(b.center_y - self.player.center_y))

    def _obs(self):
        """Build the 12-element float32 observation vector.

        Layout (index: feature):
            0  player x / width (0.5 while the player is dead)
            1  nearest enemy dx / width      (2.0 if none or player dead)
            2  nearest enemy dy / height     (2.0 if none or player dead)
            3  lowest enemy dy / height      (2.0 if none or player dead)
            4  nearest enemy bullet dx / width  (2.0 if none or player dead)
            5  nearest enemy bullet dy / height (2.0 if none or player dead)
            6  last horizontal displacement in units of PLAYER_SPEED
            7  remaining enemies / full formation size
            8  std of enemy x positions / width (0.0 if no enemies)
            9  1.0 if the player can fire this step, else 0.0
            10 remaining player respawns / initial player respawns
            11 remaining enemy waves / initial enemy respawns
        """
        if self.enemy_formation.enemies and self.player_alive:
            nearest = self._nearest_enemy()
            enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
            enemy_y = (nearest.center_y - self.player.center_y) / float(self.height)
        else:
            enemy_x = 2.0
            enemy_y = 2.0

        lowest = self._lowest_enemy()
        if lowest is not None and self.player_alive:
            lowest_dy = (lowest.center_y - self.player.center_y) / float(self.height)
        else:
            lowest_dy = 2.0

        nb = self._nearest_enemy_bullet()
        if nb is not None and self.player_alive:
            bx = (nb.center_x - self.player.center_x) / float(self.width)
            by = (nb.center_y - self.player.center_y) / float(self.height)
        else:
            bx = 2.0
            by = 2.0

        enemy_count = len(self.enemy_formation.enemies) / float(self._max_enemies())
        player_x_norm = self.player.center_x / float(self.width) if self.player_alive else 0.5

        enemy_dispersion = 0.0
        if self.enemy_formation.enemies:
            xs = np.array([e.center_x for e in self.enemy_formation.enemies], dtype=np.float32)
            enemy_dispersion = float(xs.std()) / float(self.width)

        can_shoot = 1.0 if (self.player_alive and self.current_cooldown <= 0) else 0.0

        player_respawns_norm = self.player_respawns_remaining / float(max(1, self.player_respawns))
        enemy_respawns_norm = self.enemy_respawns_remaining / float(max(1, self.enemy_respawns))

        return np.array([
            player_x_norm,
            enemy_x,
            enemy_y,
            lowest_dy,
            bx,
            by,
            self.player_speed,
            enemy_count,
            enemy_dispersion,
            can_shoot,
            player_respawns_norm,
            enemy_respawns_norm,
        ], dtype=np.float32)

    # ------------------------------------------------------------------
    # State transitions
    # ------------------------------------------------------------------

    def _respawn_player(self):
        """Spawn a fresh player, discard all enemy bullets and reset the cooldown.

        Note: this rebinds ``self.bullets`` to a NEW list that keeps only the
        player's own bullets (``direction_y == 1``).
        """
        self.player = self._spawn_player()
        self.player_alive = True
        self.bullets = [b for b in self.bullets if b.direction_y == 1]
        self.current_cooldown = 0

    def _respawn_enemies(self):
        """Move the formation back to its starting point and rebuild it via ``create_formation()``."""
        self.enemy_formation.start_x, self.enemy_formation.start_y = self._formation_origin()
        self.enemy_formation.create_formation()

    def _kill_player(self):
        """Mark the player dead and respawn them if any respawns remain.

        Returns:
            True if the episode must terminate (no respawns were left),
            False if the player was respawned.
        """
        self.player_alive = False
        if self.player_respawns_remaining > 0:
            self.player_respawns_remaining -= 1
            self._respawn_player()
            return False
        return True

    def step(self, action):
        """Advance the game by one tick.

        Args:
            action: 0 (left), 1 (right), 2 (no-op) or 3 (shoot).

        Returns:
            ``(obs, reward, terminated, truncated, info)`` per the Gymnasium
            API.  ``info`` holds ``enemies_killed``, ``step``,
            ``player_respawns_remaining`` and ``enemy_respawns_remaining``.
        """
        reward = 0.0
        terminated = False
        truncated = False

        self.current_step += 1
        if self.current_step >= self.max_steps:
            truncated = True

        if self.current_cooldown > 0:
            self.current_cooldown -= 1

        # --- Player action ---------------------------------------------
        if self.player_alive:
            prev_x = self.player.center_x

            if action == 0:
                self.player.center_x -= PLAYER_SPEED
            elif action == 1:
                self.player.center_x += PLAYER_SPEED
            elif action == 2:
                pass
            elif action == 3:
                if self.current_cooldown <= 0:
                    self.current_cooldown = self.player_attack_cooldown_steps
                    reward += 0.01  # small bonus for a legal shot
                    self.bullets.append(Bullet(self.player.center_x, self.player.center_y, 1))
                else:
                    reward -= 0.02  # discourage firing while on cooldown

            # Shaping: small bonus for standing under the horizontally nearest enemy.
            if self.enemy_formation.enemies:
                nearest = self._nearest_enemy()
                alignment = abs(nearest.center_x - self.player.center_x) / self.width
                if alignment < 0.025:
                    reward += 0.005

            self.player.center_x = np.clip(self.player.center_x, 0, self.width)
            # Actual displacement in units of PLAYER_SPEED (reduced when clipped at a wall).
            self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)

        # --- Enemy movement --------------------------------------------
        if self.enemy_formation.enemies and self.player_alive:
            # Step the formation toward the player's x (assumes move() applies
            # the signed speed along the given axis — see game.sprites).
            if self.enemy_formation.center_x < self.player.center_x:
                self.enemy_formation.move(self.width, self.height, "x", self.enemy_move_speed)
            elif self.enemy_formation.center_x > self.player.center_x:
                self.enemy_formation.move(self.width, self.height, "x", -self.enemy_move_speed)

            # 2% chance per step of a vertical shift, equally likely in either direction.
            if random.random() < 0.02:
                if random.random() < 0.5:
                    self.enemy_formation.move(self.width, self.height, "y", -self.enemy_move_speed)
                else:
                    self.enemy_formation.move(self.width, self.height, "y", self.enemy_move_speed)

        # --- Bullets ---------------------------------------------------
        bullets_to_remove = []
        # _respawn_player() rebinds self.bullets to a list without enemy
        # bullets, but this loop keeps iterating the ORIGINAL list object.
        # Once a respawn happens, skip the (already discarded) enemy bullets
        # so they cannot hit the freshly spawned player in the same step.
        enemy_bullets_cleared = False
        for b in self.bullets:
            if enemy_bullets_cleared and b.direction_y == -1:
                continue

            b.center_y += b.direction_y * BULLET_SPEED

            if b.center_y > self.height or b.center_y < 0:
                bullets_to_remove.append(b)
                continue

            if b.direction_y == 1:
                for e in self.enemy_formation.enemies:
                    if arcade.check_for_collision(b, e):
                        self.enemy_formation.remove_enemy(e)
                        bullets_to_remove.append(b)
                        reward += 10.0
                        self.enemies_killed += 1
                        break  # one hit per bullet; also avoids iterating a mutated list

            elif b.direction_y == -1 and self.player_alive:
                if arcade.check_for_collision(b, self.player):
                    bullets_to_remove.append(b)
                    reward -= 10.0
                    if self._kill_player():
                        terminated = True
                    else:
                        enemy_bullets_cleared = True

        for b in bullets_to_remove:
            # Guard: respawn may already have dropped the bullet from self.bullets.
            if b in self.bullets:
                self.bullets.remove(b)

        # --- Enemies reaching the player's row -------------------------
        if self.player_alive:
            lowest_enemy = self._lowest_enemy()
            if lowest_enemy is not None and lowest_enemy.center_y <= self.player.center_y:
                reward -= 10.0
                if self._kill_player():
                    terminated = True

        # --- Wave cleared ----------------------------------------------
        if not self.enemy_formation.enemies:
            reward += 50.0
            if self.enemy_respawns_remaining > 0:
                self.enemy_respawns_remaining -= 1
                self._respawn_enemies()
                reward += 20.0
            else:
                reward += 100.0  # final wave cleared
                terminated = True

        # --- Enemy fire: probability rises from 0.05 to 0.10 as the formation thins.
        shooting_prob = 0.05 + (0.05 * (1.0 - len(self.enemy_formation.enemies) / self._max_enemies()))
        if self.enemy_formation.enemies and random.random() < shooting_prob:
            enemy = self.enemy_formation.get_lowest_enemy()
            if enemy:
                self.bullets.append(Bullet(enemy.center_x, enemy.center_y, -1))

        # --- Shaping penalties -----------------------------------------
        if self.player_alive:
            edge_threshold = self.width * 0.1
            if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold:
                reward -= 0.03  # discourage hugging the side walls

        reward -= 0.01  # per-step time penalty

        obs = self._obs()

        return obs, float(reward), bool(terminated), bool(truncated), {
            "enemies_killed": self.enemies_killed,
            "step": self.current_step,
            "player_respawns_remaining": self.player_respawns_remaining,
            "enemy_respawns_remaining": self.enemy_respawns_remaining
        }