mirror of https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Add model training with graphs and current stats; improve the model with a better reward system
@@ -12,18 +12,21 @@ PLAYER_ATTACK_SPEED = 0.75
 BULLET_SPEED = 3
 BULLET_RADIUS = 10
 
+# default, min, max, step
 MODEL_SETTINGS = {
-    "n_steps": 2048,
-    "batch_size": 64,
-    "n_epochs": 10,
-    "learning_rate": 3e-4,
-    "gamma": 0.99,
-    "ent_coef": 0.01,
-    "clip_range": 0.2
+    "n_steps": [2048, 256, 8192, 256],
+    "batch_size": [64, 16, 512, 16],
+    "n_epochs": [10, 1, 50, 1],
+    "learning_rate": [3e-4, 1e-5, 1e-2, 1e-5],
+    "gamma": [0.99, 0.8, 0.9999, 0.001],
+    "ent_coef": [0.01, 0.0, 0.1, 0.001],
+    "clip_range": [0.2, 0.1, 0.4, 0.01],
+    "learning_steps": [500_000, 50_000, 25_000_000, 50_000]
 }
 
 menu_background_color = (30, 30, 47)
 log_dir = 'logs'
+monitor_log_dir = "training_logs"
 discord_presence_id = 1438214877343907881
 
 button_style = {'normal': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), 'hover': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK),
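
The keys in MODEL_SETTINGS match stable-baselines3 PPO constructor arguments, with learning_steps reserved for the training call. A minimal sketch of how the [default, min, max, step] tuples might be consumed, assuming stable-baselines3 is the RL library in use and that MODEL_SETTINGS and SpaceInvadersEnv are importable (module paths and the 800x600 constructor arguments are assumptions, not repo code):

    from stable_baselines3 import PPO

    # spec = [default, min, max, step]; start from the defaults.
    hyperparams = {name: spec[0] for name, spec in MODEL_SETTINGS.items()}
    total_timesteps = hyperparams.pop("learning_steps")  # used by learn(), not a PPO kwarg

    model = PPO("MlpPolicy", SpaceInvadersEnv(800, 600), **hyperparams)
    model.learn(total_timesteps=total_timesteps)

The min/max/step entries would then only matter for the settings UI, which can clamp and increment slider values without touching the training code.
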
@@ -13,7 +13,7 @@ class SpaceInvadersEnv(gym.Env):
         self.height = height
 
         self.action_space = gym.spaces.Discrete(3)
-        self.observation_space = gym.spaces.Box(low=-10.0, high=10.0, shape=(9,), dtype=np.float32)
+        self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(9,), dtype=np.float32)
 
         self.enemies = []
         self.bullets = []
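
The tighter Box(low=-2.0, high=2.0) holds because, in the step() code below, every relative feature is an offset divided by screen size (so roughly within [-1, 1]) or the out-of-band 2.0 sentinel meaning "nothing there". A small illustration of that convention (rel_pos is a hypothetical helper for exposition, not code from the repo):

    def rel_pos(obj_x, player_x, width):
        # Normalized x-offset, roughly in [-1, 1]; 2.0 flags "no object on screen".
        if obj_x is None:
            return 2.0
        return (obj_x - player_x) / float(width)

    rel_pos(500.0, 400.0, 800)   # -> 0.125, enemy slightly to the right
    rel_pos(None, 400.0, 800)    # -> 2.0, nothing to track
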
@@ -25,8 +25,14 @@ class SpaceInvadersEnv(gym.Env):
         self.prev_bx = 2.0
         self.steps_since_direction_change = 0
         self.last_direction = 0
+        self.max_steps = 1000
+        self.current_step = 0
 
     def reset(self, seed=None, options=None):
+        if seed is not None:
+            np.random.seed(seed)
+            random.seed(seed)
+
         self.enemies = []
         self.bullets = []
         self.dir_history = []
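
With the seeding above, a rollout becomes reproducible by passing the same seed to reset(). A usage sketch, assuming the Gymnasium-style return that the reset(seed=..., options=...) signature implies, plus hypothetical constructor arguments:

    env = SpaceInvadersEnv(800, 600)   # width/height values assumed
    obs, info = env.reset(seed=42)     # same seed -> same random enemy fire pattern
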
@@ -36,6 +42,7 @@ class SpaceInvadersEnv(gym.Env):
         self.prev_bx = 2.0
         self.steps_since_direction_change = 0
         self.last_direction = 0
+        self.current_step = 0
 
         start_x = self.width * 0.15
         start_y = self.height * 0.9
@@ -51,7 +58,7 @@ class SpaceInvadersEnv(gym.Env):
     def _nearest_enemy(self):
         if not self.enemies:
             return None
-        return min(self.enemies, key=lambda e: abs(e.center_y - self.player.center_y) + abs(e.center_x - self.player.center_x))
+        return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
 
     def _lowest_enemy(self):
         if not self.enemies:
@@ -104,18 +111,16 @@ class SpaceInvadersEnv(gym.Env):
         terminated = False
         truncated = False
 
+        self.current_step += 1
+        if self.current_step >= self.max_steps:
+            truncated = True
+
         nearest = self._nearest_enemy()
         if nearest is not None:
             enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
         else:
             enemy_x = 2.0
 
-        prev_bullet = self._nearest_enemy_bullet()
-        if prev_bullet is not None:
-            prev_bx = (prev_bullet.center_x - self.player.center_x) / float(self.width)
-        else:
-            prev_bx = 2.0
-
         prev_x = self.player.center_x
         current_action_dir = 0
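
max_steps and current_step add a time limit that ends episodes via truncation rather than termination, and the Gymnasium step API (assumed here, given the reset signature) reports the two separately:

    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated   # terminated: hit or board cleared; truncated: 1000-step cap
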
@@ -129,12 +134,15 @@ class SpaceInvadersEnv(gym.Env):
             t = time.perf_counter()
             if t - self.last_shot >= PLAYER_ATTACK_SPEED:
                 self.last_shot = t
+
                 b = Bullet(self.player.center_x, self.player.center_y, 1)
+
                 self.bullets.append(b)
+
                 if enemy_x != 2.0 and abs(enemy_x) < 0.04:
-                    reward += 8.0
+                    reward += 0.3
                 elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
-                    reward += 3.0
+                    reward += 0.1
 
         if self.player.center_x > self.width:
             self.player.center_x = self.width
@@ -145,8 +153,9 @@ class SpaceInvadersEnv(gym.Env):
 
         if current_action_dir != 0:
             if self.last_direction != 0 and current_action_dir != self.last_direction:
-                if self.steps_since_direction_change < 8:
-                    reward -= 3.0
+                if self.steps_since_direction_change < 3:
+                    reward -= 0.1
+
                 self.steps_since_direction_change = 0
             else:
                 self.steps_since_direction_change += 1
@@ -154,9 +163,9 @@ class SpaceInvadersEnv(gym.Env):
 
         if enemy_x != 2.0:
             if abs(enemy_x) < 0.03:
-                reward += 3.0
+                reward += 0.1
             elif abs(enemy_x) < 0.08:
-                reward += 1.0
+                reward += 0.05
 
         for b in list(self.bullets):
             b.center_y += b.direction_y * BULLET_SPEED
@@ -178,7 +187,7 @@ class SpaceInvadersEnv(gym.Env):
                     self.bullets.remove(b)
                 except ValueError:
                     pass
-                reward += 25.0
+                reward += 1.0
                 break
 
         for b in list(self.bullets):
@@ -188,14 +197,14 @@ class SpaceInvadersEnv(gym.Env):
                     self.bullets.remove(b)
                 except ValueError:
                     pass
-                reward -= 100.0
+                reward -= 5.0
                 terminated = True
 
         if not self.enemies:
-            reward += 200.0
+            reward += 10.0
             terminated = True
 
-        if self.enemies and random.random() < 0.02:
+        if self.enemies and random.random() < 0.05:
             e = random.choice(self.enemies)
             b = Bullet(e.center_x, e.center_y, -1)
             self.bullets.append(b)
@@ -203,17 +212,14 @@ class SpaceInvadersEnv(gym.Env):
         curr_bullet = self._nearest_enemy_bullet()
         if curr_bullet is not None:
             curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width)
             curr_by = (curr_bullet.center_y - self.player.center_y) / float(self.height)
         else:
             curr_bx = 2.0
             curr_by = 2.0
 
-        if prev_bx != 2.0 and curr_bx != 2.0:
-            if abs(curr_bx) > abs(prev_bx):
-                reward += 0.3
+        if self.prev_bx != 2.0 and curr_bx != 2.0:
+            if abs(curr_bx) > abs(self.prev_bx):
+                reward += 0.02
 
         if curr_bx != 2.0 and abs(curr_bx) < 0.08 and curr_by < 0.5:
-            reward -= 0.3
+            reward -= 0.01
 
         obs = self._obs()
         self.prev_bx = curr_bx
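
For reference, the reshaped reward values from the hunks above, gathered into one hypothetical table (the repo defines them inline in step(), not as a dict):

    REWARD_SCHEME = {
        "shot_while_aligned":  0.3,    # fired with |enemy_x| < 0.04
        "shot_while_near":     0.1,    # fired with |enemy_x| < 0.1
        "direction_jitter":   -0.1,    # reversed within 3 steps of the last turn
        "aligned":             0.1,    # |enemy_x| < 0.03
        "near":                0.05,   # |enemy_x| < 0.08
        "enemy_destroyed":     1.0,
        "player_hit":         -5.0,    # also terminates the episode
        "board_cleared":      10.0,    # also terminates the episode
        "dodged_bullet":       0.02,   # moved away from the nearest enemy bullet
        "bullet_threat":      -0.01,   # |curr_bx| < 0.08 and curr_by < 0.5
    }

The per-step shaping terms now sit well below the terminal rewards of -5 and +10, so a single hit or board clear still dominates an episode's return.
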