Asynchronous trainer with parallel sampling processes.
The code might look like the following:
```rust
fn train() -> Result<()> {
    let agent_configs: Vec<_> = vec![agent_config()];
    let env_config_train = env_config(name);
    let env_config_eval = env_config(name).eval();
    let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?;
    let step_proc_config = SimpleStepProcessorConfig::default();
    let actor_man_config = ActorManagerConfig::default();
    let async_trainer_config = load_async_trainer_config(model_dir.as_str())?;
    let mut recorder = TensorboardRecorder::new(model_dir);
    let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?;

    // Shared flag to stop actor threads
    let stop = Arc::new(Mutex::new(false));

    // Creates channels
    let (item_s, item_r) = unbounded(); // items pushed to replay buffer
    let (model_s, model_r) = unbounded(); // model_info

    // Guard for initialization of envs in multiple threads
    let guard_init_env = Arc::new(Mutex::new(true));

    // Actor manager and async trainer
    let mut actors = ActorManager::build(
        &actor_man_config,
        &agent_configs,
        &env_config_train,
        &step_proc_config,
        item_s,
        model_r,
        stop.clone(),
    );
    let mut trainer = AsyncTrainer::build(
        &async_trainer_config,
        &agent_configs[0],
        &env_config_eval,
        &replay_buffer_config,
        item_r,
        model_s,
        stop.clone(),
    );

    // Set the number of threads
    tch::set_num_threads(1);

    // Starts sampling and training
    actors.run(guard_init_env.clone());
    let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env);
    println!("Stats of async trainer");
    println!("{}", stats.fmt());

    let stats = actors.stop_and_join();
    println!("Stats of generated samples in actors");
    println!("{}", actor_stats_fmt(&stats));

    Ok(())
}
```
The training process consists of the following two components:

* `ActorManager` manages `Actor`s, each of which runs a thread where an `Agent` interacts with an `Env` and collects samples. Those samples are sent to the replay buffer in `AsyncTrainer` (a standalone sketch of this flow follows the list).
* `AsyncTrainer` is responsible for training the agent. It also runs a thread that pushes the samples received from `ActorManager` into the replay buffer.
The `Agent` must implement the `SyncModel` trait in order to synchronize the model of the agent in an `Actor` with the trained agent in `AsyncTrainer`. The trait provides the ability to export and import the model's information as `SyncModel::ModelInfo`.
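To make the synchronization mechanism concrete, here is a rough, self-contained sketch of what such a trait can look like. The trait name matches the one above, but the method names, the `ModelInfo` payload, and the `LinearAgent` type are illustrative assumptions, not the exact API of this crate.

```rust
/// Illustrative model-synchronization trait: the trainer exports its current
/// parameters as `ModelInfo`, and each actor-side agent imports them.
trait SyncModel {
    type ModelInfo;

    /// Export the current model parameters (trainer side).
    fn model_info(&self) -> Self::ModelInfo;

    /// Overwrite the local model with received parameters (actor side).
    fn sync_model(&mut self, info: &Self::ModelInfo);
}

/// Toy agent whose "model" is a single weight vector.
struct LinearAgent {
    weights: Vec<f32>,
}

impl SyncModel for LinearAgent {
    type ModelInfo = Vec<f32>;

    fn model_info(&self) -> Self::ModelInfo {
        self.weights.clone()
    }

    fn sync_model(&mut self, info: &Self::ModelInfo) {
        self.weights.copy_from_slice(info);
    }
}

fn main() {
    let trainer_agent = LinearAgent { weights: vec![0.5, -1.0, 2.0] };
    let mut actor_agent = LinearAgent { weights: vec![0.0, 0.0, 0.0] };

    // In the real setup this info would travel over the model channel
    // from the trainer to the actors.
    let info = trainer_agent.model_info();
    actor_agent.sync_model(&info);
    assert_eq!(actor_agent.weights, vec![0.5, -1.0, 2.0]);
}
```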
The `Agent` in `AsyncTrainer` is responsible for training, typically on a GPU, while the `Agent`s in the `Actor`s managed by `ActorManager` are responsible for sampling on CPUs. Both `AsyncTrainer` and `ActorManager` run on the same machine and communicate through channels.
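As an illustration of that device split, the sketch below uses the `tch` crate to place trainer-side parameters on a GPU when one is available and actor-side copies on the CPU. This is not the crate's own device handling (which lives in the agent implementations); the variable names and the use of `VarStore` here are assumptions for illustration.

```rust
use tch::{nn::VarStore, Device};

fn main() {
    // Trainer-side parameters live on the GPU when one is available.
    let trainer_device = Device::cuda_if_available();
    let trainer_vs = VarStore::new(trainer_device);

    // Actor-side copies stay on the CPU, since they only run inference
    // for sampling.
    let actor_vs = VarStore::new(Device::Cpu);

    // Keep intra-op parallelism low so the actor threads do not compete
    // for CPU cores (the training example above does the same).
    tch::set_num_threads(1);

    println!(
        "trainer on {:?}, actors on {:?}",
        trainer_vs.device(),
        actor_vs.device()
    );
}
```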