Remove default model i will add back later, add back train.py, fix some rewarding

This commit is contained in:
csd4ni3l
2025-11-16 22:34:29 +01:00
parent 5e87b30f78
commit c7c22695e5
4 changed files with 39 additions and 10 deletions

Binary file not shown.

30
train.py Normal file
View File

@@ -0,0 +1,30 @@
from stable_baselines3 import PPO
from utils.rl import SpaceInvadersEnv
from stable_baselines3.common.vec_env import DummyVecEnv
def make_env(rank: int, seed: int = 0):
def _init():
env = SpaceInvadersEnv()
return env
return _init
env = SpaceInvadersEnv()
n_envs = 128
env = DummyVecEnv([make_env(i) for i in range(n_envs)])
model = PPO(
"MlpPolicy",
env,
n_steps=8192,
batch_size=256,
n_epochs=7,
learning_rate=0.001,
verbose=1,
device="cpu",
gamma=0.985,
ent_coef=0.015,
clip_range=0.2,
)
model.learn(75_000_000)
model.save("invader_agent")

View File

@@ -35,29 +35,29 @@ DIFFICULTY_SETTINGS = {
DIFFICULTY_LEVELS = {
"Easy": {
"enemy_rows": 3,
"enemy_cols": 4,
"enemy_rows": 2,
"enemy_cols": 3,
"enemy_respawns": 5,
"player_count": 2,
"player_respawns": 2
},
"Medium": {
"enemy_rows": 3,
"enemy_cols": 5,
"enemy_cols": 4,
"enemy_respawns": 4,
"player_count": 4,
"player_respawns": 3
},
"Hard": {
"enemy_rows": 4,
"enemy_cols": 6,
"enemy_cols": 5,
"enemy_respawns": 3,
"player_count": 6,
"player_respawns": 4
},
"Extra Hard": {
"enemy_rows": 6,
"enemy_cols": 7,
"enemy_rows": 5,
"enemy_cols": 6,
"enemy_respawns": 2,
"player_count": 8,
"player_respawns": 5

View File

@@ -172,13 +172,13 @@ class SpaceInvadersEnv(gym.Env):
b = Bullet(self.player.center_x, self.player.center_y, 1)
self.bullets.append(b)
else:
reward -= 0.05
reward -= 0.02
if self.enemy_formation.enemies:
nearest = self._nearest_enemy()
alignment = abs(nearest.center_x - self.player.center_x) / self.width
if alignment < 0.025:
reward += 0.3
reward += 0.1
self.player.center_x = np.clip(self.player.center_x, 0, self.width)
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
@@ -222,7 +222,6 @@ class SpaceInvadersEnv(gym.Env):
if self.player_respawns_remaining > 0:
self.player_respawns_remaining -= 1
self._respawn_player()
reward += 2.0
else:
terminated = True
@@ -265,7 +264,7 @@ class SpaceInvadersEnv(gym.Env):
if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold:
reward -= 0.03
reward -= 0.005
reward -= 0.0025
obs = self._obs()