Le renforcement d'apprentissage (RL) s'attaque aux problèmes complexes, des véhicules autonomes aux modèles de langage sophistiqués. Les agents RL apprennent par l'apprentissage du renforcement de la rétroaction humaine (RLHF), adaptant leurs réponses en fonction de l'apport humain. Alors que des frameworks Python comme Keras et Tensorflow sont établis, Pytorch et Pytorch Lightning dominent les nouveaux projets.
Torchrl, une bibliothèque open source, simplifie le développement de RL avec Pytorch. Ce didacticiel montre la configuration de Torchrl, les composants de base et la construction d'un agent RL de base. Nous explorerons des algorithmes prédéfinis comme l'optimisation de la politique proximale (PPO) et les techniques de journalisation et de surveillance essentielles.
Cette section vous guide dans l'installation et l'utilisation de Torchrl.
Avant d'installer Torchrl, assurez-vous d'avoir:
Installez les conditions préalables:
!pip install torch tensordict gymnasium==0.29.1 pygame
Installez Torchrl à l'aide de PIP. Un environnement conda est recommandé pour les ordinateurs personnels ou les serveurs.
!pip install torchrl
Testez votre installation en important torchrl
dans un shell ou un cahier Python. Utilisez check_env_specs()
pour vérifier la compatibilité de l'environnement (par exemple, cartpole):
import torchrl from torchrl.envs import GymEnv from torchrl.envs.utils import check_env_specs check_env_specs(GymEnv("CartPole-v1"))
Une installation réussie s'affiche:
<code>[torchrl][INFO] check_env_specs succeeded!</code>
Avant la création d'agents, examinons les éléments principaux de Torchrl.
Torchrl fournit une API cohérente pour divers environnements, enveloppez des fonctions spécifiques à l'environnement en emballages standard. Cela simplifie l'interaction:
Créez un environnement de gymnase en utilisant GymEnv
:
env = GymEnv("CartPole-v1")
Améliorer les environnements avec des modules complémentaires (par exemple, des compteurs de pas) en utilisant TransformedEnv
:
from torchrl.envs import GymEnv, StepCounter, TransformedEnv env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
La normalisation est obtenue avec ObservationNorm
:
from torchrl.envs import Compose base_env = GymEnv('CartPole-v1', device=device) env = TransformedEnv( base_env, Compose( ObservationNorm(in_keys=["observation"]), StepCounter() ) )
Les transformations multiples sont combinées en utilisant Compose
.
L'agent utilise une politique pour sélectionner des actions en fonction de l'état de l'environnement, visant à maximiser les récompenses cumulatives.
Une stratégie aléatoire simple est créée en utilisant RandomPolicy
:
!pip install torch tensordict gymnasium==0.29.1 pygame
Cette section démontre la construction d'un agent RL simple.
Importer les packages nécessaires:
!pip install torchrl
Nous utiliserons l'environnement de cartpole:
import torchrl from torchrl.envs import GymEnv from torchrl.envs.utils import check_env_specs check_env_specs(GymEnv("CartPole-v1"))
Définir les hyperparamètres:
<code>[torchrl][INFO] check_env_specs succeeded!</code>
Définir une politique de réseau neuronal simple:
env = GymEnv("CartPole-v1")
Créer un collecteur de données et un tampon de relecture:
from torchrl.envs import GymEnv, StepCounter, TransformedEnv env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
Définir les modules de formation:
from torchrl.envs import Compose base_env = GymEnv('CartPole-v1', device=device) env = TransformedEnv( base_env, Compose( ObservationNorm(in_keys=["observation"]), StepCounter() ) )
Implémentez la boucle de formation (simplifiée pour la concision):
import torchrl import torch from tensordict import TensorDict from torchrl.data.tensor_specs import Bounded action_spec = Bounded(-torch.ones(1), torch.ones(1)) actor = torchrl.envs.utils.RandomPolicy(action_spec=action_spec) td = actor(TensorDict({}, batch_size=[])) print(td.get("action"))
Ajouter une évaluation et une journalisation à la boucle de formation (simplifiée):
import time import matplotlib.pyplot as plt from torchrl.envs import GymEnv, StepCounter, TransformedEnv from tensordict.nn import TensorDictModule as TensorDict, TensorDictSequential as Seq from torchrl.modules import EGreedyModule, MLP, QValueModule from torchrl.objectives import DQNLoss, SoftUpdate from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, ReplayBuffer from torch.optim import Adam from torchrl._utils import logger as torchrl_logger
Imprimer le temps de formation et les résultats de l'intrigue:
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) torch.manual_seed(0) env.set_seed(0)
(l'implémentation complète du DQN est disponible dans le manuel de données référencé.)
Torchrl propose des algorithmes pré-construits (DQN, DDPG, SAC, PPO, etc.). Cette section démontre l'utilisation de PPO.
Importer les modules nécessaires:
INIT_RAND_STEPS = 5000 FRAMES_PER_BATCH = 100 OPTIM_STEPS = 10 EPS_0 = 0.5 BUFFER_LEN = 100_000 ALPHA = 0.05 TARGET_UPDATE_EPS = 0.95 REPLAY_BUFFER_SAMPLE = 128 LOG_EVERY = 1000 MLP_SIZE = 64
Définir les hyperparamètres:
value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[MLP_SIZE, MLP_SIZE]) value_net = TensorDict(value_mlp, in_keys=["observation"], out_keys=["action_value"]) policy = Seq(value_net, QValueModule(spec=env.action_spec)) exploration_module = EGreedyModule( env.action_spec, annealing_num_steps=BUFFER_LEN, eps_init=EPS_0 ) policy_explore = Seq(policy, exploration_module)
(L'implémentation de PPO restante, y compris les définitions de réseau, la collecte de données, la fonction de perte, l'optimisation et la boucle de formation, suit une structure similaire à la réponse originale mais est omise ici par concision. Reportez-vous à la réponse d'origine pour le code complet.)
Surveiller les progrès de la formation à l'aide de Tensorboard:
collector = SyncDataCollector( env, policy_explore, frames_per_batch=FRAMES_PER_BATCH, total_frames=-1, init_random_frames=INIT_RAND_STEPS, ) rb = ReplayBuffer(storage=LazyTensorStorage(BUFFER_LEN))
Visualisez avec: tensorboard --logdir="training_logs"
Le débogage consiste à vérifier les spécifications de l'environnement:
loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True) optim = Adam(loss.parameters(), lr=ALPHA) updater = SoftUpdate(loss, eps=TARGET_UPDATE_EPS)
Échantillons d'observations et d'actions:
total_count = 0 total_episodes = 0 t0 = time.time() success_steps = [] for i, data in enumerate(collector): rb.extend(data) # ... (training steps, similar to the original response) ...
visualiser les performances de l'agent en rendant une vidéo (nécessite torchvision
et av
):
# ... (training steps) ... if total_count > 0 and total_count % LOG_EVERY == 0: torchrl_logger.info(f"Successful steps: {max_length}, episodes: {total_episodes}") if max_length > 475: print("TRAINING COMPLETE") break
Ce tutoriel a fourni une introduction complète à Torchrl, présentant ses capacités via des exemples DQN et PPO. Expérimentez avec différents environnements et algorithmes pour améliorer encore vos compétences RL. Les ressources référencées offrent des opportunités d'apprentissage supplémentaires.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!