diff --git a/invader_agent.zip b/invader_agent.zip deleted file mode 100644 index 49979f2..0000000 Binary files a/invader_agent.zip and /dev/null differ diff --git a/train.py b/train.py new file mode 100644 index 0000000..af3142e --- /dev/null +++ b/train.py @@ -0,0 +1,30 @@ +from stable_baselines3 import PPO +from utils.rl import SpaceInvadersEnv +from stable_baselines3.common.vec_env import DummyVecEnv + +def make_env(rank: int, seed: int = 0): + def _init(): + env = SpaceInvadersEnv() + return env + return _init + +env = SpaceInvadersEnv() + +n_envs = 128 + +env = DummyVecEnv([make_env(i) for i in range(n_envs)]) +model = PPO( + "MlpPolicy", + env, + n_steps=8192, + batch_size=256, + n_epochs=7, + learning_rate=0.001, + verbose=1, + device="cpu", + gamma=0.985, + ent_coef=0.015, + clip_range=0.2, +) +model.learn(75_000_000) +model.save("invader_agent") \ No newline at end of file diff --git a/utils/constants.py b/utils/constants.py index e226ba1..2a891d9 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -35,29 +35,29 @@ DIFFICULTY_SETTINGS = { DIFFICULTY_LEVELS = { "Easy": { - "enemy_rows": 3, - "enemy_cols": 4, + "enemy_rows": 2, + "enemy_cols": 3, "enemy_respawns": 5, "player_count": 2, "player_respawns": 2 }, "Medium": { "enemy_rows": 3, - "enemy_cols": 5, + "enemy_cols": 4, "enemy_respawns": 4, "player_count": 4, "player_respawns": 3 }, "Hard": { "enemy_rows": 4, - "enemy_cols": 6, + "enemy_cols": 5, "enemy_respawns": 3, "player_count": 6, "player_respawns": 4 }, "Extra Hard": { - "enemy_rows": 6, - "enemy_cols": 7, + "enemy_rows": 5, + "enemy_cols": 6, "enemy_respawns": 2, "player_count": 8, "player_respawns": 5 diff --git a/utils/rl.py b/utils/rl.py index 6b42cc0..d846403 100644 --- a/utils/rl.py +++ b/utils/rl.py @@ -172,13 +172,13 @@ class SpaceInvadersEnv(gym.Env): b = Bullet(self.player.center_x, self.player.center_y, 1) self.bullets.append(b) else: - reward -= 0.05 + reward -= 0.02 if self.enemy_formation.enemies: nearest = self._nearest_enemy() alignment = abs(nearest.center_x - self.player.center_x) / self.width if alignment < 0.025: - reward += 0.3 + reward += 0.1 self.player.center_x = np.clip(self.player.center_x, 0, self.width) self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED) @@ -222,7 +222,6 @@ class SpaceInvadersEnv(gym.Env): if self.player_respawns_remaining > 0: self.player_respawns_remaining -= 1 self._respawn_player() - reward += 2.0 else: terminated = True @@ -265,7 +264,7 @@ class SpaceInvadersEnv(gym.Env): if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold: reward -= 0.03 - reward -= 0.005 + reward -= 0.0025 obs = self._obs()