mirror of
https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Remove the default model (I will add it back later), add back train.py, and fix some of the rewards
This commit is contained in:
Binary file not shown.
30
train.py
Normal file
30
train.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from stable_baselines3 import PPO
|
||||||
|
from utils.rl import SpaceInvadersEnv
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
|
||||||
|
def make_env(rank: int, seed: int = 0):
    """Return a zero-argument factory that builds one SpaceInvadersEnv.

    The returned callable is what a vectorized-environment wrapper expects:
    it is invoked later, once per slot, to create that slot's environment.

    Args:
        rank: Index of the environment slot. Accepted for call-site
            compatibility; currently not used when building the environment.
        seed: Base seed. Accepted for call-site compatibility; currently not
            used when building the environment.

    Returns:
        A callable taking no arguments and returning a new SpaceInvadersEnv.
    """
    def _build():
        # A fresh instance on every call, so each slot owns its own state.
        return SpaceInvadersEnv()

    return _build
|
||||||
|
|
||||||
|
# Number of environment copies wrapped by the vectorized environment.
n_envs = 128

# Each factory from make_env builds one independent SpaceInvadersEnv; the
# vectorized wrapper is the only environment the model ever interacts with,
# so no standalone instance is created beforehand.
env = DummyVecEnv([make_env(i) for i in range(n_envs)])

# PPO agent with an MLP policy, trained on CPU.
model = PPO(
    "MlpPolicy",
    env,
    n_steps=8192,
    batch_size=256,
    n_epochs=7,
    learning_rate=0.001,
    verbose=1,
    device="cpu",
    gamma=0.985,
    ent_coef=0.015,
    clip_range=0.2,
)

# Train for 75 million timesteps, then persist the agent as "invader_agent".
model.learn(75_000_000)
model.save("invader_agent")
|
||||||
@@ -35,29 +35,29 @@ DIFFICULTY_SETTINGS = {
|
|||||||
|
|
||||||
DIFFICULTY_LEVELS = {
|
DIFFICULTY_LEVELS = {
|
||||||
"Easy": {
|
"Easy": {
|
||||||
"enemy_rows": 3,
|
"enemy_rows": 2,
|
||||||
"enemy_cols": 4,
|
"enemy_cols": 3,
|
||||||
"enemy_respawns": 5,
|
"enemy_respawns": 5,
|
||||||
"player_count": 2,
|
"player_count": 2,
|
||||||
"player_respawns": 2
|
"player_respawns": 2
|
||||||
},
|
},
|
||||||
"Medium": {
|
"Medium": {
|
||||||
"enemy_rows": 3,
|
"enemy_rows": 3,
|
||||||
"enemy_cols": 5,
|
"enemy_cols": 4,
|
||||||
"enemy_respawns": 4,
|
"enemy_respawns": 4,
|
||||||
"player_count": 4,
|
"player_count": 4,
|
||||||
"player_respawns": 3
|
"player_respawns": 3
|
||||||
},
|
},
|
||||||
"Hard": {
|
"Hard": {
|
||||||
"enemy_rows": 4,
|
"enemy_rows": 4,
|
||||||
"enemy_cols": 6,
|
"enemy_cols": 5,
|
||||||
"enemy_respawns": 3,
|
"enemy_respawns": 3,
|
||||||
"player_count": 6,
|
"player_count": 6,
|
||||||
"player_respawns": 4
|
"player_respawns": 4
|
||||||
},
|
},
|
||||||
"Extra Hard": {
|
"Extra Hard": {
|
||||||
"enemy_rows": 6,
|
"enemy_rows": 5,
|
||||||
"enemy_cols": 7,
|
"enemy_cols": 6,
|
||||||
"enemy_respawns": 2,
|
"enemy_respawns": 2,
|
||||||
"player_count": 8,
|
"player_count": 8,
|
||||||
"player_respawns": 5
|
"player_respawns": 5
|
||||||
|
|||||||
@@ -172,13 +172,13 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
b = Bullet(self.player.center_x, self.player.center_y, 1)
|
b = Bullet(self.player.center_x, self.player.center_y, 1)
|
||||||
self.bullets.append(b)
|
self.bullets.append(b)
|
||||||
else:
|
else:
|
||||||
reward -= 0.05
|
reward -= 0.02
|
||||||
|
|
||||||
if self.enemy_formation.enemies:
|
if self.enemy_formation.enemies:
|
||||||
nearest = self._nearest_enemy()
|
nearest = self._nearest_enemy()
|
||||||
alignment = abs(nearest.center_x - self.player.center_x) / self.width
|
alignment = abs(nearest.center_x - self.player.center_x) / self.width
|
||||||
if alignment < 0.025:
|
if alignment < 0.025:
|
||||||
reward += 0.3
|
reward += 0.1
|
||||||
|
|
||||||
self.player.center_x = np.clip(self.player.center_x, 0, self.width)
|
self.player.center_x = np.clip(self.player.center_x, 0, self.width)
|
||||||
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
|
self.player_speed = (self.player.center_x - prev_x) / max(1e-6, PLAYER_SPEED)
|
||||||
@@ -222,7 +222,6 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
if self.player_respawns_remaining > 0:
|
if self.player_respawns_remaining > 0:
|
||||||
self.player_respawns_remaining -= 1
|
self.player_respawns_remaining -= 1
|
||||||
self._respawn_player()
|
self._respawn_player()
|
||||||
reward += 2.0
|
|
||||||
else:
|
else:
|
||||||
terminated = True
|
terminated = True
|
||||||
|
|
||||||
@@ -265,7 +264,7 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold:
|
if self.player.center_x < edge_threshold or self.player.center_x > self.width - edge_threshold:
|
||||||
reward -= 0.03
|
reward -= 0.03
|
||||||
|
|
||||||
reward -= 0.005
|
reward -= 0.0025
|
||||||
|
|
||||||
obs = self._obs()
|
obs = self._obs()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user