强化学习Agent系列(二)——PyGame虚拟环境创建与Python 贪吃蛇Agent制作实战教学

文章目录

  • 一、前言
  • 二、gymnasium 简单虚拟环境创建
    • 1、gymnasium介绍
    • 2、gymnasium 贪吃蛇简单示例
  • 三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent
    • 1、虚拟环境
    • 2、虚拟环境注册
    • 3、训练程序
    • 4、模型测试
  • 三、卷积虚拟环境
    • 1、卷积神经网络虚拟环境
    • 2、训练代码

一、前言

大家好,未来的开发者们请上座
随着人工智能的发展,强化学习基本会再次来到人们眼前,遂想制作一下相关的教程。强化学习第一步基本离不开虚拟环境的搭建,下面用大家耳熟能详的贪吃蛇游戏为基础,制作一个Agent,完成对这个游戏的绝杀。
万里长城第二步:用python开发贪吃蛇智能体****加粗样式

二、gymnasium 简单虚拟环境创建

1、gymnasium介绍

gymnasium(此前称为gym)是一个由 OpenAI 开发的 Python 库,用于开发和比较强化学习算法。它提供了一组丰富的环境,模拟了各种任务,包括但不限于经典的控制问题、像素级游戏、机器人模拟等。

以下是gymnasium库的一些主要特点:

  1. 环境多样性:gymnasium包含了一系列不同的环境,每个环境都有其独特的观察空间(输入)和动作空间(输出)。这些环境涵盖了从简单的文本控制任务到复杂的三维视觉任务的广泛范围。

  2. 标准化API:gymnasium库提供了一个简单且统一的API来与这些环境交互。这使得研究人员和开发人员可以轻松地用相同的代码测试和比较不同的强化学习算法。

  3. 扩展性:用户可以创建自定义环境并将其集成到gymnasium框架中,这使得库能够适应各种不同的研究需求和应用场景。

  4. 评估标准:gymnasium环境通常包括预定义的评估标准,如累积回报或任务完成时间,这有助于在不同算法间进行公平的比较。

  5. 社区支持:由于gymnasium是由OpenAI推出并得到了强化学习社区的广泛支持,因此有大量的教程、论坛讨论和第三方资源可供学习和参考。

  6. 可视化和监控:gymnasium提供了工具来可视化智能体的性能,并允许监控和记录实验过程,便于分析和调试。

使用gymnasium的基本步骤通常包括:

导入gymnasium库。
创建一个环境实例。
初始化环境。
在一个循环中,根据当前观察值选择动作,执行动作,并接收环境的反馈(新的观察值、奖励、完成状态等)。
结束实验并关闭环境。

2、gymnasium 贪吃蛇简单示例

下面是贪吃蛇虚拟环境的一个简单的示例
在本次示例中,暂未进行任何训练。一切行为主要是从状态空间中随机抽取一个动作并执行。
下面是gymnasium创建的虚拟环境的三个核心函数介绍

  • reset(): 这个函数用于重置环境到初始状态,并返回初始状态的观测值。在开始每个新的episode时,通常会调用这个函数来初始化环境。

  • step(action): 这个函数用于让Agent在环境中执行一个动作(action),并返回四个值:观测值(observation),奖励(reward),是否终止(done),以及额外信息(info)。Agent根据环境返回的信息来决定下一步的动作。

  • render(): 这个函数用于在屏幕上渲染当前环境的状态,通常用于可视化环境以便观察Agent的行为。不是所有的环境都支持渲染,具体取决于环境的实现。

下面是具体的示例代码

import timeimport pygame
import sys
import random
import numpy as np
import gymnasium as gym
class SnakeEnv(gym.Env):def __init__(self):super().__init__()# 初始化Pygamepygame.init()# 屏幕宽高self.SCREEN_WIDTH=240self.SCREEN_HEIGHT=240#蛇的方块大小self.snakeCell=10# 创建窗口self.screen = pygame.display.set_mode((self.SCREEN_WIDTH,self.SCREEN_HEIGHT))pygame.display.set_caption('Snake_Game')self.action_space=gym.spaces.Discrete(4) #动作空间为4self.observation_space=gym.spaces.Box(low=0,high=7,shape=(self.SCREEN_WIDTH,self.SCREEN_HEIGHT),dtype=np.uint8)# 重启def reset(self):"""重置蛇和食物的位置"""# 蛇的初始位置self.snake_head=[100,50]self.snake_body=[[100,50],[100-self.snakeCell,50],[100-self.snakeCell*2,50]]self.len=3# 食物的初始位置self.food_pos=[random.randint(1,self.SCREEN_WIDTH//10-1)*10,random.randint(1,self.SCREEN_HEIGHT//10-1)*10]return self._get_observation()# 根据当前状态 和action 执行动作def step(self,action):# 定义动作到方向的映射directionDict={'LEFT':[1,0],'RIGHT':[-1,0],'UP':[0,-1],'DOWN':[0,1]}action_to_direction = {0: "UP",1: "DOWN",2: "LEFT",3: "RIGHT"}directionTarget=action_to_direction[action]nextPosDelay=np.array(directionDict[directionTarget])*self.snakeCell #加的位置self.snake_head=list(np.array(self.snake_body[0])+nextPosDelay)if self.snake_head in self.snake_body:return self._get_observation(), 0, True, False, {}self.snake_body.insert(0,self.snake_head)# 如果是吃到食物,就重新刷新果子,同时长度 +1if self.food_pos == self.snake_head:self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10)) * 10,random.randrange(1, (self.SCREEN_HEIGHT // 10)) * 10]self.len+=1# 弹出while self.len<len(self.snake_body):self.snake_body.pop()# 奖励reward,done=self._get_reward()truncated = Truereturn self._get_observation(), reward, truncated, done, {}# 渲染def render(self,mode="human"):# 实现可视化screen = self.screen# 颜色定义WHITE= (255,255,255)GREEN = (0,255,0)RED = (255,0,0)# 清空屏幕screen.fill(WHITE)# 画蛇和食物for pos in self.snake_body:pygame.draw.rect(screen,GREEN,pygame.Rect(pos[0],pos[1],self.snakeCell,self.snakeCell))pygame.draw.rect(screen,RED,pygame.Rect(self.food_pos[0],self.food_pos[1],self.snakeCell,self.snakeCell))pygame.display.update()# 获取奖励def _get_reward(self):# 计算奖励reward = 0done = False# 检查蛇是否吃到食物if self.snake_head:reward+=10# 检查蛇是否撞到墙壁或自身head=self.snake_headif head[0]<0 or head[0]>self.SCREEN_WIDTH-10 or head[1]<0 or head[1]>self.SCREEN_HEIGHT-10:reward = -10done = Truereturn reward,done# 获取当前观察空间def _get_observation(self):# 获取窗口内容作为观察值observation = pygame.display.get_surface()# 将观察值调整为指定的宽度和高度# observation = pygame.transform.scale(observation, (self.SCREEN_WIDTH, self.SCREEN_HEIGHT))return observationdef main():snakeEnv=Snake()snakeEnv.reset()done=Falsewhile not done:# 获取事件for event in pygame.event.get():# 处理退出事件if event.type == pygame.QUIT:pygame.quit()done = True# 从动作空间随机获取一个动作action= snakeEnv.action_space.sample()screen, reward, truncated, done,_=snakeEnv.step(action)snakeEnv.render()time.sleep(0.03)
if __name__=="__main__":main()

程序运行截图:
在这里插入图片描述

三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent

在上一步中,你已经创建出你要的虚拟环境了,现在让我们在这个创建好的环境中进行训练吧!

1、虚拟环境

SnakeEnv2.py

import timeimport pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optionalclass SnakeEnv(gym.Env):metadata = {"render_modes": ["human", "rgb_array"],"render_fps": 30,}def __init__(self, render_mode="human"):super().__init__()# 初始化Pygamepygame.init()# 屏幕宽高self.SCREEN_WIDTH = 100self.SCREEN_HEIGHT = 100# 蛇的方块大小self.snakeCell = 10# 游戏速度self.speed = 12self.clock = pygame.time.Clock()# 创建窗口self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))pygame.display.set_caption('Snake_Game')self.render_mode = render_modeself.action_space = gym.spaces.Discrete(4)  # 动作空间为4self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT),dtype=np.float32)# 初始化蛇和食物的位置等属性# ...self.num_timesepts = 0# 重启def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):"""重置蛇和食物的位置"""super().reset(seed=seed)self.curStep = 0  # 步数# 蛇的初始位置self.snake_next = [60, 50]self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]self.len = 3# 食物的初始位置self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]info = {}#return self._get_observation(), info# 根据当前状态 和action 执行动作def step(self, action):self.num_timesepts += 1  # 步骤统计加1# 定义动作到方向的映射directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}action_to_direction = {0: "UP",1: "DOWN",2: "LEFT",3: "RIGHT"}directionTarget = action_to_direction[action]nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置# 输入action 获取到 snake_next 下一步self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)if self.snake_next == self.snake_body[1]:return self._get_observation(), -0.5, False, False, {}# 奖励reward, terminated = self._get_reward()# 如果是吃到食物,就重新刷新果子,同时长度 +1if self.food_pos == self.snake_next:self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]self.len += 1truncated = Falseinfo = {}# if self.render_mode == "human":if self.render_mode == "human" and self.num_timesepts % 5000 > 4000 and self.num_timesepts > 10000:self.render()for event in pygame.event.get():if event == pygame.QUIT:pygame.quit()sys.exit()if not terminated:self.snake_body.insert(0, self.snake_next)# 弹出while self.len < len(self.snake_body):self.snake_body.pop()return self._get_observation(), reward,terminated,  truncated,{}# 渲染def render(self):# 实现可视化screen = self.screen# 颜色定义WHITE = (255, 255, 255)GREEN = (0, 255, 0)RED = (255, 0, 0)# 清空屏幕screen.fill(WHITE)# 画蛇和食物snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)for i in range(len(self.snake_body)):pos = self.snake_body[len(self.snake_body) - i - 1]color = [int(round(component * snakecolor[i])) for component in GREEN]pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))pygame.display.update()self.clock.tick(self.speed)def GetDic(self,p1,p2):return np.linalg.norm(np.array(p1) - np.array(p2))# 获取奖励def _get_reward(self):# 计算奖励self.curStep += 1  # 步数reward = 0terminated = Falseflag=0# 正向激励# 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励if self.snake_next == self.food_pos:reward += 500 + pow(5, self.len)self.curStep = 0#print(reward)# 负向激励# 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励head = self.snake_nextif head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:reward -= 100 / self.lenterminated = Trueself.curStep = 0# 摸鱼步数超过一定值就开始负向奖励if self.curStep > 100 * self.len:reward -=  1 / self.len# 中向激励if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):reward += 2 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_body[0],self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.else:reward -= 1 / self.len#print(reward * 0.3)if reward<0:#print(reward * 0.2)pass#print(reward* 0.2)return reward * 0.2, terminated# 获取当前观察空间def _get_observation(self):# 返回观察空间,也就是一个二维数组obs = np.zeros((self.SCREEN_WIDTH, self.SCREEN_HEIGHT), dtype=np.float32)obs[tuple(np.transpose(self.snake_body))] = np.linspace(0.8, 0.2, len(self.snake_body), dtype=np.float32)obs[tuple(self.snake_body[0])] = 1.0obs[tuple(self.food_pos)] = -1.0return obsdef main():snakeEnv = SnakeEnv()snakeEnv.reset()done = Falsewhile not done:# 获取事件for event in pygame.event.get():# 处理退出事件if event.type == pygame.QUIT:pygame.quit()done = True# 从动作空间随机获取一个动作action = snakeEnv.action_space.sample()screen, reward, truncated, done, _ = snakeEnv.step(action)snakeEnv.render()if __name__ == "__main__":main()

2、虚拟环境注册

打开当前项目的site-packages
在这里插入图片描述
找到gymnasium
将其SnakeEnv2.py放置如下,并在init.py中添加调用注册函数
在这里插入图片描述
到这里就注册完毕,可以进行训练了

3、训练程序

snake_train.py 具体代码如下

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear scheduler
def linear_schedule(initial_value, final_value=0.0):if isinstance(initial_value, str):initial_value = float(initial_value)final_value = float(final_value)assert (initial_value > 0.0)def scheduler(progress):return final_value + progress * (initial_value - final_value)return scheduler# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnv-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-2, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)model = PPO("MlpPolicy", env, verbose=1, device="cuda",n_steps=2048,batch_size=512,n_epochs=4,gamma=0.94,learning_rate=lr_schedule,clip_range=clip_range_schedule,)
# 4、训练模型:# Set the save directory
num=1
save_dir="trained_models_mlp"
while True:save_dir = "trained_models_mlp_{}".format(num)if not os.path.exists(save_dir):os.mkdir(save_dir)breakelse:num +=1checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)model.learn(total_timesteps=int(200000),callback=[checkpoint_callback]
)# Restore stdout
sys.stdout = original_stdout# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))

4、模型测试

对训练好的模型进行测试,可以用如下代码

import time
import random
from sb3_contrib import MaskablePPO
from stable_baselines3 import PPO
from snakecnn23 import SnakeEnv
import pygameMODEL_PATH=r'H:\AILab\RL\Snaker2\trained_models_cnn\ppo_snake_final'
# Load the trained model
model = MaskablePPO.load(MODEL_PATH)snakeEnv = SnakeEnv()
for i in range(10):obs,info=snakeEnv.reset()terminated = Falsewhile not terminated:# 获取事件for event in pygame.event.get():if event == pygame.QUIT:pygame.quit()# 从动作空间随机获取一个动作action ,_=  model.predict(obs, action_masks=snakeEnv.get_action_mask())prev_mask = snakeEnv.get_action_mask()action_value=int(action.item())obs, reward,  terminated, truncated, _ = snakeEnv.step(action_value)snakeEnv.render()

三、卷积虚拟环境

上面的是基于多层感知机,上限有限,可能效果不是很好,可以对其进行一点点改进
核心是修改
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),dtype=np.uint8)
和 _get_observation() 观察空间
改完这两个其他的基本不用变

1、卷积神经网络虚拟环境

import timeimport pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optionalclass SnakeEnv(gym.Env):metadata = {"render_modes": ["human", "rgb_array"],"render_fps": 30,}def __init__(self, render_mode="human"):super().__init__()# 初始化Pygamepygame.init()# 屏幕宽高self.SCREEN_WIDTH = 84self.SCREEN_HEIGHT = 84# 蛇的方块大小self.snakeCell = 7# 游戏速度self.speed = 12self.clock = pygame.time.Clock()# 创建窗口self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))pygame.display.set_caption('Snake_Game')self.render_mode = render_modeself.action_space = gym.spaces.Discrete(4)  # 动作空间为4self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),dtype=np.uint8)# 初始化蛇和食物的位置等属性# ...self.num_timesepts = 0# 重启def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):"""重置蛇和食物的位置"""super().reset(seed=seed)self.curStep = 0  # 步数# 蛇的初始位置self.snake_next = [60, 50]self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]self.len = 3# 食物的初始位置self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]info = {}#return self._get_observation(), info# 根据当前状态 和action 执行动作def step(self, action):self.num_timesepts += 1  # 步骤统计加1# 定义动作到方向的映射directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}action_to_direction = {0: "UP",1: "DOWN",2: "LEFT",3: "RIGHT"}directionTarget = action_to_direction[action]nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置# 输入action 获取到 snake_next 下一步self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)if self.snake_next == self.snake_body[1]:return self._get_observation(), -0.5, False, False, {}# 奖励reward, terminated = self._get_reward(action)# 如果是吃到食物,就重新刷新果子,同时长度 +1if self.food_pos == self.snake_next:self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]self.len += 1truncated = Falseinfo = {}# if self.render_mode == "human":if self.render_mode == "human" and self.num_timesepts % 5000 > 4500 and self.num_timesepts > 10000:self.render()for event in pygame.event.get():if event == pygame.QUIT:pygame.quit()sys.exit()if not terminated:self.snake_body.insert(0, self.snake_next)# 弹出while self.len < len(self.snake_body):self.snake_body.pop()return self._get_observation(), reward,terminated,  truncated,{}# 渲染def render(self):# 实现可视化screen = self.screen# 颜色定义WHITE = (255, 255, 255)GREEN = (0, 255, 0)RED = (255, 0, 0)# 清空屏幕screen.fill(WHITE)# 画蛇和食物snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)for i in range(len(self.snake_body)):pos = self.snake_body[len(self.snake_body) - i - 1]color = [int(round(component * snakecolor[i])) for component in GREEN]pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))pygame.display.update()self.clock.tick(self.speed)def GetDic(self,p1,p2):return np.linalg.norm(np.array(p1) - np.array(p2))# 获取奖励def _get_reward(self,action):# 计算奖励self.curStep += 1  # 步数reward = 0terminated = Falseflag=0# 正向激励# 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励if self.snake_next == self.food_pos:reward += 400 + pow(5, self.len)self.curStep = 0# print(reward)#print(action)# print(self.snake_body,self.food_pos,self.snake_next)# 负向激励# 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励head = self.snake_nextif head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:reward -= 200 / self.lenterminated = Trueself.curStep = 0# 摸鱼步数超过一定值就开始负向奖励if self.curStep > 250 * self.len:reward -=  1 / self.len# 中向激励if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):reward += 4 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_next,self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.elif self.curStep>50 and self.GetDic(self.snake_next,self.food_pos)>= self.GetDic(self.snake_body[0],self.food_pos):reward -= 2 / self.len * self.GetDic(self.snake_next,self.food_pos) /self.SCREEN_WIDTH#print(reward * 0.3)if reward<0:#print(reward * 0.2)pass# print(reward* 0.2)return reward * 0.2, terminated# 获取当前观察空间def _get_observation(self):obs = np.zeros((self.SCREEN_WIDTH//self.snakeCell, self.SCREEN_HEIGHT//self.snakeCell), dtype=np.uint8)# Set the snake body to gray with linearly decreasing intensity from head to tail.newsnake=np.array(self.snake_body)//7obs[tuple(np.transpose(newsnake))] = np.linspace(200, 50, len(newsnake), dtype=np.uint8)# Stack single layer into 3-channel-image.obs = np.stack((obs, obs, obs), axis=-1)# Set the snake head to green and the tail to blueobs[tuple(newsnake[0])] = [0, 255, 0]obs[tuple(newsnake[-1])] = [255, 0, 0]# Set the food to redobs[np.array(self.food_pos)//7] = [0, 0, 255]# Enlarge the observation to 84x84obs = np.repeat(np.repeat(obs, self.snakeCell, axis=0), self.snakeCell, axis=1)return obsdef main():snakeEnv = SnakeEnv()snakeEnv.reset()done = Falsewhile not done:# 获取事件for event in pygame.event.get():# 处理退出事件if event.type == pygame.QUIT:pygame.quit()done = True# 从动作空间随机获取一个动作action = snakeEnv.action_space.sample()screen, reward, truncated, done, _ = snakeEnv.step(action)snakeEnv.render()if __name__ == "__main__":main()

2、训练代码

核心是修改算法名,之前用MlpPolicy,现在改为CnnPolicy
其余不变

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear schedulerfrom stable_baselines3 import PPO
def linear_schedule(initial_value, final_value=0.0):if isinstance(initial_value, str):initial_value = float(initial_value)final_value = float(final_value)assert (initial_value > 0.0)def scheduler(progress):return final_value + progress * (initial_value - final_value)return scheduler
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnvcnn-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-3, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)model = PPO(  "CnnPolicy",env,device="cuda",verbose=1,n_steps=2048,batch_size=512,n_epochs=4,gamma=0.94,learning_rate=lr_schedule,clip_range=clip_range_schedule,tensorboard_log=LOG_DIR)
# 4、训练模型:# Set the save directory
num=1
save_dir="trained_models_cnn"
while True:save_dir = "trained_models_cnn_{}".format(num)if not os.path.exists(save_dir):os.mkdir(save_dir)breakelse:num +=1checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)model.learn(total_timesteps=int(200000),callback=[checkpoint_callback]
)# Restore stdout
sys.stdout = original_stdout# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))# 5、测试训练好的模型:
obs = env.reset()for i in range(1000):action, _states = model.predict(obs, deterministic=True)observation, reward, terminated, truncated, info = env.step(action)env.render()if terminated:obs = env.reset()
env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2815622.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL】主从同步原理、分库分表

主从同步原理 1. 主从同步原理 MySQL 经常先把命令拷入硬盘的日志&#xff0c;再执行日志的命令&#xff0c;这样的好处&#xff1a; 日志的位置固定&#xff0c;拷入硬盘的开销不大&#xff1b;将命令先准备好&#xff0c;而不是边读边执行&#xff0c;性能更好&#xff0c;…

GIS之深度学习01:检测电脑是否包含英伟达GPU

GPU&#xff08;Graphics processing unit&#xff09;&#xff0c;中文全称图形处理器&#xff0c;我们听说的更多的CPU全称是central processing unit&#xff0c;中央处理器。研究深度学习和神经网络大都离不开GPU&#xff0c;在GPU的加持下&#xff0c;我们可以更快的获得模…

Golang使用Swag搭建api文档

1. 简介 Gin是Golang目前最为常用的Web框架之一。 公司项目验收需要API接口设计说明书&#xff08;Golang后端服务基于Gin框架编写&#xff09;&#xff0c;编写任务自然就落到了我们研发人员身上。 项目经理提供了文档模板&#xff0c;让我们参考模板来手动编写&#xff0c;要…

幻兽帕鲁(1.5.0)可视化管理工具(0.5.7 docker版)安装教程

文章目录 局域网帕鲁服务器部署教程帕鲁服务可视化工具安装配置服务器地址&#xff08;可跳过&#xff09;使用工具管理面板 1.5.0服务端RCON错误1.5.0服务端无法启动RCON端口 解决方法第一步&#xff1a;PalWorldSettings.ini配置第二步&#xff1a;修改PalServer.sh配置 局域…

如何一步一步地优化LVGL的丝滑度

经过一番周折将LVGL移植到了STM32F407单片机上&#xff0c;底层驱动的LCD是st7789&#xff0c;移植时的条件和环境如下&#xff1a; ●LVGL用的是单缓冲&#xff0c;一次刷新10行&#xff1b; ●刷新函数用的是最原始的一个一个打点的方式&#xff1b; ●ST7789底层发送数据用的…

抖音橱窗怎么关闭?

路径&#xff1a;【抖音APP-我-电商带货-全部工具-账号管理-权限与账户-关闭电商权限】关闭橱窗带货权限。 0保、原0粉达人关闭电商权限后&#xff0c;后续均需要满足页面提示要求才可以进行电商权限的开通&#xff0c;建议慎重操作。

2024年阿里云优惠券领取及使用教程_无门槛优惠券

阿里云优惠代金券领取入口&#xff0c;阿里云服务器优惠代金券、域名代金券&#xff0c;在领券中心可以领取当前最新可用的满减代金券&#xff0c;阿里云百科aliyunbaike.com分享阿里云服务器代金券、领券中心、域名代金券领取、代金券查询及使用方法&#xff1a; 阿里云优惠券…

Xinstall助力社交产品,轻松实现用户关系链归因

在如今的社交网络时代&#xff0c;人与人之间的联系变得越来越紧密。社交产品作为连接人与人之间的桥梁&#xff0c;其重要性不言而喻。然而&#xff0c;如何让用户在使用社交产品时更快速地建立起社交联系&#xff0c;一直是社交产品开发者们关注的焦点。 Xinstall作为一款专…

智能销售数据大屏:决胜市场的数字利器

在数字化浪潮席卷全球的今天&#xff0c;数据已经成为企业决策的核心要素。尤其对于销售团队来说&#xff0c;如何快速、准确地把握市场动态&#xff0c;分析客户行为&#xff0c;成为决定胜负的关键。而智能销售数据大屏&#xff0c;正是这样一款能够帮助企业洞察市场脉络、决…

如何用GPT高效地处理文本、文献查阅、PPT编辑、编程、绘图和论文写作?

原文链接&#xff1a;如何用GPT高效地处理文本、文献查阅、PPT编辑、编程、绘图和论文写作?https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247594986&idx4&sn970f9ba75998f2dd9fa5707d1611a6cc&chksmfa82320dcdf5bb1bdf58c20686d4eb209770e68253ed90d…

StarRocks实战——滴滴OLAP的技术实践与发展方向

原文大佬的这篇StarRocks实践文章整体写的很深入&#xff0c;介绍了StarRocks数仓架构设计、物化视图加速实时看板、全局字典精确去重等内容&#xff0c;这里直接摘抄下来用作学习和知识沉淀。 目录 一、背景介绍 1.1 滴滴OLAP的发展历程 1.2 OLAP引擎存在的痛点 1.2.1 运维…

高通 AI Hub 上手指南

文章介绍 2月26日&#xff0c;高通在2024年世界移动通信大会&#xff08;MWC2024&#xff09;上发布高通AI Hub&#xff0c; AI Hub 简化了AI 模型部署到边缘设备的过程。可以利用AI-hub云端托管 Qualcomm 设备上&#xff0c;在几分钟内完成模型的优化、验证和部署。本文以Pyto…

老卫带你学---leetcode刷题(103. 二叉树的锯齿形层序遍历)

103. 二叉树的锯齿形层序遍历 问题 给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c;以此类推&#xff0c;层与层之间交替进行&#xff09;。 示例 1&#xff1a; 输入&a…

优先级队列介绍和模拟实现

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 Container就是适配器&#xff0c;也就是说这个优先级队列的底层就使用vector&#xff0c;当然&#xff0c;我们也可以使用deque来适配&#xff0c;但是对于优先级队列来说&#xff0c;效率是低于vector的&#xff0c;因为优先…

【前端素材】推荐优质在线家具电商Bazu平台模板(附源码)

一、需求分析 1、系统定义 家具电商平台是指专门销售家具产品的在线电子商务平台。这些平台专注于家具类商品的销售和服务&#xff0c;为消费者提供方便快捷的购买体验。 2、功能需求 家具电商平台是指专门销售家具产品的在线电子商务平台。这些平台专注于家具类商品的销售…

从此告别手忙脚乱!职场高手管理多微信的技巧

在现在职场&#xff0c;微信已经成为了必不可少的沟通工具之一。然而&#xff0c;随着工作联系人的增多&#xff0c;管理多个微信账号可能会变得一团糟。今天&#xff0c;我们就来分享一些职场高手们管理多个微信账号的技巧&#xff0c;让你从此告别手忙脚乱的状态&#xff01;…

搜维尔科技:OptiTrack 提供了性能最佳的动作捕捉平台

OptiTrack 动画 我们的 Prime 系列相机和 Motive 软件相结合&#xff0c;产生了世界上最大的捕获量、最精确的 3D 数据和有史以来最高的相机数量。OptiTrack 提供了性能最佳的动作捕捉平台&#xff0c;具有易于使用的制作工作流程以及运行世界上最大舞台所需的深度。 无与伦比…

Python 画 箱线图

Python 画 箱线图 flyfish 箱线图 其他名字 盒须图 / 箱形图 横向用正态分布看 垂直看 pandas画 import pandas as pdimport seaborn as sns import matplotlib.pyplot as plt import pandas as pddf pd.read_csv(sh300.csv) print("原始数据") print(df.he…

ARTS Week 17

Algorithm 本周的算法题为 989. 数组形式的整数加法 整数的 数组形式 num 是按照从左到右的顺序表示其数字的数组。 例如&#xff0c;对于 num 1321 &#xff0c;数组形式是 [1,3,2,1] 。 给定 num &#xff0c;整数的 数组形式 &#xff0c;和整数 k &#xff0c;返回 整数 n…

【算法历练】动态规划副本—路径问题

&#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;宙でおやすみ 1:02━━━━━━️&#x1f49f;──────── 2:45 &#x1f504; ◀️ ⏸ ▶️ ☰ &#…