mirror of
https://github.com/csd4ni3l/fleet-commander.git
synced 2026-01-01 04:23:47 +01:00
Add model training with graphs and current stats, improve model with better rewarding system
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -180,3 +180,4 @@ test*.py
|
|||||||
logs/
|
logs/
|
||||||
logs
|
logs
|
||||||
settings.json
|
settings.json
|
||||||
|
training_logs
|
||||||
@@ -1 +1,5 @@
|
|||||||
Fleet Commander is like Space Invaders but you are the enemy instead of the player.
|
Fleet Commander is like Space Invaders but you are the enemy instead of the player.
|
||||||
|
|
||||||
|
It uses AI (Reinforcement Learning) for the Player, and you, the Enemy has to defeat it.
|
||||||
|
|
||||||
|
You can train yourself, or use the default model which comes with the game.
|
||||||
@@ -34,7 +34,7 @@ class Game(arcade.gui.UIView):
|
|||||||
def on_show_view(self):
|
def on_show_view(self):
|
||||||
super().on_show_view()
|
super().on_show_view()
|
||||||
|
|
||||||
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top")
|
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
||||||
self.back_button.on_click = lambda event: self.main_exit()
|
self.back_button.on_click = lambda event: self.main_exit()
|
||||||
|
|
||||||
def main_exit(self):
|
def main_exit(self):
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class Player(arcade.Sprite): # Not actually the player
|
|||||||
|
|
||||||
def update(self, model: PPO, enemies, bullets, width, height):
|
def update(self, model: PPO, enemies, bullets, width, height):
|
||||||
if enemies:
|
if enemies:
|
||||||
nearest_enemy = min(enemies, key=lambda e: abs(e.center_y - self.center_y) + abs(e.center_x - self.center_x))
|
nearest_enemy = min(enemies, key=lambda e: abs(e.center_x - self.center_x))
|
||||||
enemy_x = (nearest_enemy.center_x - self.center_x) / width
|
enemy_x = (nearest_enemy.center_x - self.center_x) / width
|
||||||
enemy_y = (nearest_enemy.center_y - self.center_y) / height
|
enemy_y = (nearest_enemy.center_y - self.center_y) / height
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -55,8 +55,8 @@ class Main(arcade.gui.UIView):
|
|||||||
self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
self.play_button = self.box.add(arcade.gui.UITextureButton(text="Play", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
||||||
self.play_button.on_click = lambda event: self.play()
|
self.play_button.on_click = lambda event: self.play()
|
||||||
|
|
||||||
self.train_button = self.box.add(arcade.gui.UITextureButton(text="Train", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
self.train_model_button = self.box.add(arcade.gui.UITextureButton(text="Train Model", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
||||||
self.train_button.on_click = lambda event: self.train()
|
self.train_model_button.on_click = lambda event: self.train_model()
|
||||||
|
|
||||||
self.settings_button = self.box.add(arcade.gui.UITextureButton(text="Settings", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
self.settings_button = self.box.add(arcade.gui.UITextureButton(text="Settings", texture=button_texture, texture_hovered=button_hovered_texture, width=self.window.width / 2, height=150, style=big_button_style))
|
||||||
self.settings_button.on_click = lambda event: self.settings()
|
self.settings_button.on_click = lambda event: self.settings()
|
||||||
@@ -68,3 +68,7 @@ class Main(arcade.gui.UIView):
|
|||||||
def settings(self):
|
def settings(self):
|
||||||
from menus.settings import Settings
|
from menus.settings import Settings
|
||||||
self.window.show_view(Settings(self.pypresence_client))
|
self.window.show_view(Settings(self.pypresence_client))
|
||||||
|
|
||||||
|
def train_model(self):
|
||||||
|
from menus.train_model import TrainModel
|
||||||
|
self.window.show_view(TrainModel(self.pypresence_client))
|
||||||
@@ -1,60 +1,177 @@
|
|||||||
import arcade, arcade.gui
|
import arcade, arcade.gui, threading, io, os, time
|
||||||
|
|
||||||
from utils.constants import button_style, MODEL_SETTINGS
|
from utils.constants import button_style, MODEL_SETTINGS, monitor_log_dir
|
||||||
from utils.preload import button_texture, button_hovered_texture
|
from utils.preload import button_texture, button_hovered_texture
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from utils.ml import SpaceInvadersEnv
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
from stable_baselines3.common.logger import configure
|
||||||
|
|
||||||
|
from utils.rl import SpaceInvadersEnv
|
||||||
|
|
||||||
class TrainModel(arcade.gui.UIView):
|
class TrainModel(arcade.gui.UIView):
|
||||||
def __init__(self, pypresence_client):
|
def __init__(self, pypresence_client):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.pypresence_client = pypresence_client
|
self.pypresence_client = pypresence_client
|
||||||
self.pypresence_client.update(state="Model Training")
|
self.pypresence_client.update(state="Model Training")
|
||||||
|
|
||||||
self.current_state = "settings"
|
|
||||||
|
|
||||||
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
|
self.anchor = self.add_widget(arcade.gui.UIAnchorLayout(size_hint=(1, 1)))
|
||||||
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))
|
self.box = self.anchor.add(arcade.gui.UIBoxLayout(space_between=10))
|
||||||
|
|
||||||
self.settings = MODEL_SETTINGS.copy()
|
self.settings = {
|
||||||
|
setting: data[0] # default value
|
||||||
|
for setting, data in MODEL_SETTINGS.items()
|
||||||
|
}
|
||||||
|
self.labels = {}
|
||||||
|
|
||||||
|
self.training = False
|
||||||
|
self.training_text = ""
|
||||||
|
|
||||||
|
self.last_progress_update = time.perf_counter()
|
||||||
|
|
||||||
def on_show_view(self):
|
def on_show_view(self):
|
||||||
super().on_show_view()
|
super().on_show_view()
|
||||||
|
|
||||||
self.show_menu(self.current_state)
|
self.show_menu()
|
||||||
|
|
||||||
def show_menu(self, state):
|
def main_exit(self):
|
||||||
if state == "settings":
|
from menus.main import Main
|
||||||
self.box.add(arcade.gui.UILabel("Settings", font_size=48))
|
self.window.show_view(Main(self.pypresence_client))
|
||||||
|
|
||||||
for setting, data in MODEL_SETTINGS:
|
def show_menu(self):
|
||||||
default, min_value, max_value, step = data
|
self.back_button = self.anchor.add(arcade.gui.UITextureButton(texture=button_texture, texture_hovered=button_hovered_texture, text='<--', style=button_style, width=100, height=50), anchor_x="left", anchor_y="top", align_x=5, align_y=-5)
|
||||||
self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}"))
|
self.back_button.on_click = lambda event: self.main_exit()
|
||||||
|
|
||||||
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step))
|
self.box.add(arcade.gui.UILabel("Settings", font_size=36))
|
||||||
slider._render_steps = lambda surface: None
|
|
||||||
slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)
|
|
||||||
|
|
||||||
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
|
for setting, data in MODEL_SETTINGS.items():
|
||||||
train_button.on_click = lambda e: self.train()
|
default, min_value, max_value, step = data
|
||||||
|
label = self.box.add(arcade.gui.UILabel(text=f"{setting.replace('_', ' ').capitalize()}: {default}", font_size=18))
|
||||||
|
|
||||||
|
slider = self.box.add(arcade.gui.UISlider(value=default, min_value=min_value, max_value=max_value, step=step, width=self.window.width / 2, height=self.window.height / 25))
|
||||||
|
slider._render_steps = lambda surface: None
|
||||||
|
slider.on_change = lambda e, key=setting: self.change_value(key, e.new_value)
|
||||||
|
|
||||||
|
self.labels[setting] = label
|
||||||
|
|
||||||
|
train_button = self.box.add(arcade.gui.UITextureButton(width=self.window.width / 2, height=self.window.height / 10, text="Train", style=button_style, texture=button_texture, texture_hovered=button_hovered_texture))
|
||||||
|
train_button.on_click = lambda e: self.start_training()
|
||||||
|
|
||||||
def change_value(self, key, value):
|
def change_value(self, key, value):
|
||||||
...
|
self.labels[key].text = f"{key.replace('_', ' ').capitalize()}: {self.round_near_int(value)}"
|
||||||
|
self.settings[key] = self.round_near_int(value)
|
||||||
|
|
||||||
|
def start_training(self):
|
||||||
|
self.box.clear()
|
||||||
|
|
||||||
|
self.training = True
|
||||||
|
|
||||||
|
self.training_label = self.box.add(arcade.gui.UILabel("No Output yet.", font_size=16, multiline=True, width=self.window.width / 2, height=self.window.height / 2))
|
||||||
|
|
||||||
|
self.plot_image_widget = self.box.add(arcade.gui.UIImage(texture=arcade.Texture.create_empty("empty", (1, 1))))
|
||||||
|
self.plot_image_widget.visible = False
|
||||||
|
|
||||||
|
threading.Thread(target=self.train, daemon=True).start()
|
||||||
|
|
||||||
|
def on_update(self, delta_time):
|
||||||
|
if self.training and os.path.exists(os.path.join("training_logs", "progress.csv")) and time.perf_counter() - self.last_progress_update >= 0.5:
|
||||||
|
self.last_progress_update = time.perf_counter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
progress_df = pd.read_csv(os.path.join("training_logs", "progress.csv"))
|
||||||
|
except pd.errors.EmptyDataError:
|
||||||
|
return
|
||||||
|
|
||||||
|
progress_text = ""
|
||||||
|
|
||||||
|
for key, value in progress_df.items():
|
||||||
|
progress_text += f"{key}: {round(value.iloc[-1], 6)}\n"
|
||||||
|
|
||||||
|
self.training_text = progress_text
|
||||||
|
|
||||||
|
if hasattr(self, "training_label"):
|
||||||
|
self.training_label.text = self.training_text
|
||||||
|
|
||||||
|
def round_near_int(self, x, tol=1e-4):
|
||||||
|
nearest = round(x)
|
||||||
|
if abs(x - nearest) < tol:
|
||||||
|
return nearest
|
||||||
|
return x
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
env = SpaceInvadersEnv()
|
os.makedirs(monitor_log_dir, exist_ok=True)
|
||||||
|
env = Monitor(SpaceInvadersEnv(), filename=os.path.join(monitor_log_dir, "monitor.csv"))
|
||||||
|
|
||||||
model = PPO(
|
model = PPO(
|
||||||
"MlpPolicy",
|
"MlpPolicy",
|
||||||
env,
|
env,
|
||||||
n_steps=2048,
|
n_steps=self.settings["n_steps"],
|
||||||
batch_size=64,
|
batch_size=self.settings["batch_size"],
|
||||||
n_epochs=10,
|
n_epochs=self.settings["n_epochs"],
|
||||||
learning_rate=3e-4,
|
learning_rate=self.settings["learning_rate"],
|
||||||
verbose=1,
|
verbose=1,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
gamma=0.99,
|
gamma=self.settings["gamma"],
|
||||||
ent_coef=0.01,
|
ent_coef=self.settings["ent_coef"],
|
||||||
clip_range=0.2
|
clip_range=self.settings["clip_range"],
|
||||||
)
|
)
|
||||||
model.learn(1_000_000)
|
|
||||||
|
new_logger = configure(
|
||||||
|
folder=monitor_log_dir, format_strings=["csv"]
|
||||||
|
)
|
||||||
|
|
||||||
|
model.set_logger(new_logger)
|
||||||
|
|
||||||
|
model.learn(self.settings["learning_steps"])
|
||||||
model.save("invader_agent")
|
model.save("invader_agent")
|
||||||
|
|
||||||
|
self.training = False
|
||||||
|
|
||||||
|
self.plot_results(os.path.join(monitor_log_dir, "monitor.csv"), os.path.join(monitor_log_dir, "progress.csv"))
|
||||||
|
|
||||||
|
def plot_results(self, log_path, loss_log_path):
|
||||||
|
df = pd.read_csv(log_path, skiprows=1)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(2, 1, figsize=(6, 8), dpi=100)
|
||||||
|
|
||||||
|
loss_df = pd.read_csv(loss_log_path)
|
||||||
|
|
||||||
|
axes[0].plot(np.cumsum(df['l']), df['r'], label='Reward')
|
||||||
|
axes[0].set_title('PPO Training: Episodic Reward')
|
||||||
|
axes[0].set_xlabel('Total Timesteps')
|
||||||
|
axes[0].set_ylabel('Reward')
|
||||||
|
axes[0].grid(True)
|
||||||
|
|
||||||
|
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/policy_gradient_loss'], label='Policy Gradient Loss')
|
||||||
|
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/value_loss'], label='Value Loss')
|
||||||
|
axes[1].plot(loss_df['time/total_timesteps'], loss_df['train/explained_variance'], label='Explained Variance')
|
||||||
|
axes[1].set_title('PPO Training: Loss Functions')
|
||||||
|
axes[1].set_xlabel('Total Timesteps')
|
||||||
|
axes[1].set_ylabel('Loss Value')
|
||||||
|
axes[1].legend()
|
||||||
|
axes[1].grid(True)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
buffer = BytesIO()
|
||||||
|
plt.savefig(buffer, format='png')
|
||||||
|
buffer.seek(0)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
plot_texture = arcade.Texture(Image.open(buffer))
|
||||||
|
|
||||||
|
self.plot_image_widget.texture = plot_texture
|
||||||
|
self.plot_image_widget.size_hint = (None, None)
|
||||||
|
self.plot_image_widget.width = plot_texture.width
|
||||||
|
self.plot_image_widget.height = plot_texture.height
|
||||||
|
|
||||||
|
self.plot_image_widget.visible = True
|
||||||
|
self.training_text = "Training finished. Plot displayed."
|
||||||
3
run.py
3
run.py
@@ -10,7 +10,6 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
|||||||
pyglet.resource.path.append(script_dir)
|
pyglet.resource.path.append(script_dir)
|
||||||
pyglet.font.add_directory(os.path.join(script_dir, 'assets', 'fonts'))
|
pyglet.font.add_directory(os.path.join(script_dir, 'assets', 'fonts'))
|
||||||
|
|
||||||
|
|
||||||
from utils.utils import get_closest_resolution, print_debug_info, on_exception
|
from utils.utils import get_closest_resolution, print_debug_info, on_exception
|
||||||
from utils.constants import log_dir, menu_background_color
|
from utils.constants import log_dir, menu_background_color
|
||||||
from menus.main import Main
|
from menus.main import Main
|
||||||
@@ -18,8 +17,6 @@ from arcade.experimental.controller_window import ControllerWindow
|
|||||||
|
|
||||||
sys.excepthook = on_exception
|
sys.excepthook = on_exception
|
||||||
|
|
||||||
__builtins__.print = lambda *args, **kwargs: logging.debug(" ".join(map(str, args)))
|
|
||||||
|
|
||||||
if not log_dir in os.listdir():
|
if not log_dir in os.listdir():
|
||||||
os.makedirs(log_dir)
|
os.makedirs(log_dir)
|
||||||
|
|
||||||
|
|||||||
20
train.py
20
train.py
@@ -1,20 +0,0 @@
|
|||||||
from stable_baselines3 import PPO
|
|
||||||
from utils.ml import SpaceInvadersEnv
|
|
||||||
|
|
||||||
env = SpaceInvadersEnv()
|
|
||||||
model = PPO(
|
|
||||||
"MlpPolicy",
|
|
||||||
env,
|
|
||||||
n_steps=2048,
|
|
||||||
batch_size=64,
|
|
||||||
n_epochs=10,
|
|
||||||
learning_rate=3e-4,
|
|
||||||
verbose=1,
|
|
||||||
device="cpu",
|
|
||||||
gamma=0.99,
|
|
||||||
ent_coef=0.02,
|
|
||||||
clip_range=0.2,
|
|
||||||
gae_lambda=0.95
|
|
||||||
)
|
|
||||||
model.learn(1_000_000)
|
|
||||||
model.save("invader_agent")
|
|
||||||
@@ -12,18 +12,21 @@ PLAYER_ATTACK_SPEED = 0.75
|
|||||||
BULLET_SPEED = 3
|
BULLET_SPEED = 3
|
||||||
BULLET_RADIUS = 10
|
BULLET_RADIUS = 10
|
||||||
|
|
||||||
|
# default, min, max, step
|
||||||
MODEL_SETTINGS = {
|
MODEL_SETTINGS = {
|
||||||
"n_steps": [2048, 256, 8192, 256],
|
"n_steps": [2048, 256, 8192, 256],
|
||||||
"batch_size": 64,
|
"batch_size": [64, 16, 512, 16],
|
||||||
"n_epochs": 10,
|
"n_epochs": [10, 1, 50, 1],
|
||||||
"learning_rate": 3e-4,
|
"learning_rate": [3e-4, 1e-5, 1e-2, 1e-5],
|
||||||
"gamma": 0.99,
|
"gamma": [0.99, 0.8, 0.9999, 0.001],
|
||||||
"ent_coef": 0.01,
|
"ent_coef": [0.01, 0.0, 0.1, 0.001],
|
||||||
"clip_range": 0.2
|
"clip_range": [0.2, 0.1, 0.4, 0.01],
|
||||||
|
"learning_steps": [500_000, 50_000, 25_000_000, 50_000]
|
||||||
}
|
}
|
||||||
|
|
||||||
menu_background_color = (30, 30, 47)
|
menu_background_color = (30, 30, 47)
|
||||||
log_dir = 'logs'
|
log_dir = 'logs'
|
||||||
|
monitor_log_dir = "training_logs"
|
||||||
discord_presence_id = 1438214877343907881
|
discord_presence_id = 1438214877343907881
|
||||||
|
|
||||||
button_style = {'normal': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), 'hover': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK),
|
button_style = {'normal': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK), 'hover': UITextureButtonStyle(font_name="Roboto", font_color=arcade.color.BLACK),
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
self.height = height
|
self.height = height
|
||||||
|
|
||||||
self.action_space = gym.spaces.Discrete(3)
|
self.action_space = gym.spaces.Discrete(3)
|
||||||
self.observation_space = gym.spaces.Box(low=-10.0, high=10.0, shape=(9,), dtype=np.float32)
|
self.observation_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(9,), dtype=np.float32)
|
||||||
|
|
||||||
self.enemies = []
|
self.enemies = []
|
||||||
self.bullets = []
|
self.bullets = []
|
||||||
@@ -25,8 +25,14 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
self.prev_bx = 2.0
|
self.prev_bx = 2.0
|
||||||
self.steps_since_direction_change = 0
|
self.steps_since_direction_change = 0
|
||||||
self.last_direction = 0
|
self.last_direction = 0
|
||||||
|
self.max_steps = 1000
|
||||||
|
self.current_step = 0
|
||||||
|
|
||||||
def reset(self, seed=None, options=None):
|
def reset(self, seed=None, options=None):
|
||||||
|
if seed is not None:
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
self.enemies = []
|
self.enemies = []
|
||||||
self.bullets = []
|
self.bullets = []
|
||||||
self.dir_history = []
|
self.dir_history = []
|
||||||
@@ -36,6 +42,7 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
self.prev_bx = 2.0
|
self.prev_bx = 2.0
|
||||||
self.steps_since_direction_change = 0
|
self.steps_since_direction_change = 0
|
||||||
self.last_direction = 0
|
self.last_direction = 0
|
||||||
|
self.current_step = 0
|
||||||
|
|
||||||
start_x = self.width * 0.15
|
start_x = self.width * 0.15
|
||||||
start_y = self.height * 0.9
|
start_y = self.height * 0.9
|
||||||
@@ -51,7 +58,7 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
def _nearest_enemy(self):
|
def _nearest_enemy(self):
|
||||||
if not self.enemies:
|
if not self.enemies:
|
||||||
return None
|
return None
|
||||||
return min(self.enemies, key=lambda e: abs(e.center_y - self.player.center_y) + abs(e.center_x - self.player.center_x))
|
return min(self.enemies, key=lambda e: abs(e.center_x - self.player.center_x))
|
||||||
|
|
||||||
def _lowest_enemy(self):
|
def _lowest_enemy(self):
|
||||||
if not self.enemies:
|
if not self.enemies:
|
||||||
@@ -104,18 +111,16 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
terminated = False
|
terminated = False
|
||||||
truncated = False
|
truncated = False
|
||||||
|
|
||||||
|
self.current_step += 1
|
||||||
|
if self.current_step >= self.max_steps:
|
||||||
|
truncated = True
|
||||||
|
|
||||||
nearest = self._nearest_enemy()
|
nearest = self._nearest_enemy()
|
||||||
if nearest is not None:
|
if nearest is not None:
|
||||||
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
|
enemy_x = (nearest.center_x - self.player.center_x) / float(self.width)
|
||||||
else:
|
else:
|
||||||
enemy_x = 2.0
|
enemy_x = 2.0
|
||||||
|
|
||||||
prev_bullet = self._nearest_enemy_bullet()
|
|
||||||
if prev_bullet is not None:
|
|
||||||
prev_bx = (prev_bullet.center_x - self.player.center_x) / float(self.width)
|
|
||||||
else:
|
|
||||||
prev_bx = 2.0
|
|
||||||
|
|
||||||
prev_x = self.player.center_x
|
prev_x = self.player.center_x
|
||||||
current_action_dir = 0
|
current_action_dir = 0
|
||||||
|
|
||||||
@@ -129,12 +134,15 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
if t - self.last_shot >= PLAYER_ATTACK_SPEED:
|
if t - self.last_shot >= PLAYER_ATTACK_SPEED:
|
||||||
self.last_shot = t
|
self.last_shot = t
|
||||||
|
|
||||||
b = Bullet(self.player.center_x, self.player.center_y, 1)
|
b = Bullet(self.player.center_x, self.player.center_y, 1)
|
||||||
|
|
||||||
self.bullets.append(b)
|
self.bullets.append(b)
|
||||||
|
|
||||||
if enemy_x != 2.0 and abs(enemy_x) < 0.04:
|
if enemy_x != 2.0 and abs(enemy_x) < 0.04:
|
||||||
reward += 8.0
|
reward += 0.3
|
||||||
elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
|
elif enemy_x != 2.0 and abs(enemy_x) < 0.1:
|
||||||
reward += 3.0
|
reward += 0.1
|
||||||
|
|
||||||
if self.player.center_x > self.width:
|
if self.player.center_x > self.width:
|
||||||
self.player.center_x = self.width
|
self.player.center_x = self.width
|
||||||
@@ -145,8 +153,9 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
|
|
||||||
if current_action_dir != 0:
|
if current_action_dir != 0:
|
||||||
if self.last_direction != 0 and current_action_dir != self.last_direction:
|
if self.last_direction != 0 and current_action_dir != self.last_direction:
|
||||||
if self.steps_since_direction_change < 8:
|
if self.steps_since_direction_change < 3:
|
||||||
reward -= 3.0
|
reward -= 0.1
|
||||||
|
|
||||||
self.steps_since_direction_change = 0
|
self.steps_since_direction_change = 0
|
||||||
else:
|
else:
|
||||||
self.steps_since_direction_change += 1
|
self.steps_since_direction_change += 1
|
||||||
@@ -154,9 +163,9 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
|
|
||||||
if enemy_x != 2.0:
|
if enemy_x != 2.0:
|
||||||
if abs(enemy_x) < 0.03:
|
if abs(enemy_x) < 0.03:
|
||||||
reward += 3.0
|
reward += 0.1
|
||||||
elif abs(enemy_x) < 0.08:
|
elif abs(enemy_x) < 0.08:
|
||||||
reward += 1.0
|
reward += 0.05
|
||||||
|
|
||||||
for b in list(self.bullets):
|
for b in list(self.bullets):
|
||||||
b.center_y += b.direction_y * BULLET_SPEED
|
b.center_y += b.direction_y * BULLET_SPEED
|
||||||
@@ -178,7 +187,7 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
self.bullets.remove(b)
|
self.bullets.remove(b)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
reward += 25.0
|
reward += 1.0
|
||||||
break
|
break
|
||||||
|
|
||||||
for b in list(self.bullets):
|
for b in list(self.bullets):
|
||||||
@@ -188,14 +197,14 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
self.bullets.remove(b)
|
self.bullets.remove(b)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
reward -= 100.0
|
reward -= 5.0
|
||||||
terminated = True
|
terminated = True
|
||||||
|
|
||||||
if not self.enemies:
|
if not self.enemies:
|
||||||
reward += 200.0
|
reward += 10.0
|
||||||
terminated = True
|
terminated = True
|
||||||
|
|
||||||
if self.enemies and random.random() < 0.02:
|
if self.enemies and random.random() < 0.05:
|
||||||
e = random.choice(self.enemies)
|
e = random.choice(self.enemies)
|
||||||
b = Bullet(e.center_x, e.center_y, -1)
|
b = Bullet(e.center_x, e.center_y, -1)
|
||||||
self.bullets.append(b)
|
self.bullets.append(b)
|
||||||
@@ -203,17 +212,14 @@ class SpaceInvadersEnv(gym.Env):
|
|||||||
curr_bullet = self._nearest_enemy_bullet()
|
curr_bullet = self._nearest_enemy_bullet()
|
||||||
if curr_bullet is not None:
|
if curr_bullet is not None:
|
||||||
curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width)
|
curr_bx = (curr_bullet.center_x - self.player.center_x) / float(self.width)
|
||||||
curr_by = (curr_bullet.center_y - self.player.center_y) / float(self.height)
|
|
||||||
else:
|
else:
|
||||||
curr_bx = 2.0
|
curr_bx = 2.0
|
||||||
curr_by = 2.0
|
|
||||||
|
|
||||||
if prev_bx != 2.0 and curr_bx != 2.0:
|
if self.prev_bx != 2.0 and curr_bx != 2.0:
|
||||||
if abs(curr_bx) > abs(prev_bx):
|
if abs(curr_bx) > abs(self.prev_bx):
|
||||||
reward += 0.3
|
reward += 0.02
|
||||||
|
|
||||||
if curr_bx != 2.0 and abs(curr_bx) < 0.08 and curr_by < 0.5:
|
reward -= 0.01
|
||||||
reward -= 0.3
|
|
||||||
|
|
||||||
obs = self._obs()
|
obs = self._obs()
|
||||||
self.prev_bx = curr_bx
|
self.prev_bx = curr_bx
|
||||||
Reference in New Issue
Block a user