Main function for PPO training and evaluation of the Beluga Challenge.
def main():
    """!
    @brief Main function for PPO training and evaluation of the Beluga Challenge

    Parses command line arguments and executes the appropriate mode:
    - train: Train the PPO agent
    - eval: Evaluate trained model performance
    - problem: Evaluate agent on specific problem instances
    """

    parser = argparse.ArgumentParser(description='PPO Training and Evaluation for Beluga Challenge')

    # --- Mode selection ---
    parser.add_argument('--mode', type=str, choices=['train', 'eval', 'problem'], default='train',
                        help='Mode: train (Training), eval (Model Evaluation), problem (Problem Evaluation)')

    # --- Training options ---
    # BUG FIX: the original declared --train_old_models with action='store_true'
    # AND default=True, so the flag was a no-op — args.train_old_models was
    # always True and could never be disabled. Keep --train_old_models for
    # backward compatibility and add a paired --no_train_old_models that
    # writes False to the same dest.
    parser.add_argument('--train_old_models', dest='train_old_models',
                        action='store_true', default=True,
                        help='Load existing models (default: True)')
    parser.add_argument('--no_train_old_models', dest='train_old_models',
                        action='store_false',
                        help='Start training from scratch instead of loading existing models')
    parser.add_argument('--use_permutation', action='store_true', default=False,
                        help='Use observation permutation (default: False)')
    parser.add_argument('--n_episodes', type=int, default=10000,
                        help='Number of training episodes (default: 10000)')
    parser.add_argument('--base_index', type=int, default=61,
                        help='Base index for problem selection (default: 61)')

    # --- Evaluation options ---
    parser.add_argument('--n_eval_episodes', type=int, default=10,
                        help='Number of evaluation episodes (default: 10)')
    parser.add_argument('--max_steps', type=int, default=200,
                        help='Maximum steps per episode (default: 200)')
    parser.add_argument('--plot', action='store_true', default=False,
                        help='Show plot after evaluation (default: False)')

    # --- Problem-evaluation options ---
    parser.add_argument('--problem_path', type=str, default="problems/problem_90_s132_j137_r8_oc81_f43.json",
                        help='Path to problem for evaluation (default: problems/problem_90_s132_j137_r8_oc81_f43.json (Biggest Problem with 10 Racks))')
    parser.add_argument('--max_problem_steps', type=int, default=20000,
                        help='Maximum steps for problem evaluation (default: 20000)')
    parser.add_argument('--save_to_file', action='store_true', default=False,
                        help='Save results to TXT file (default: False)')

    args = parser.parse_args()

    # Environment over the problem files; base_index selects which problem(s)
    # to start from (semantics defined by Env — TODO confirm).
    env = Env(path="problems/", base_index=args.base_index)

    # PPO hyperparameters. input_dims=40 is presumably the fixed observation
    # vector size expected by the policy network — verify against Env.
    n_actions = 8        # size of the discrete action space
    batch_size = 128     # minibatch size per PPO update
    n_epochs = 5         # optimization epochs per rollout
    alpha = 0.0005       # learning rate
    N = 1024             # rollout horizon / steps between learning updates
    ppo_agent = PPOAgent(n_actions=n_actions, batch_size=batch_size, alpha=alpha,
                         n_epochs=n_epochs, input_dims=40, policy_clip=0.2, N=N, model_name="ppo")

    trainer = Trainer(env=env, ppo_agent=ppo_agent, debug=False)

    # Dispatch on the requested mode; argparse 'choices' guarantees one of these.
    if args.mode == 'train':
        print(f"Starting training with {args.n_episodes} episodes...")
        trainer.train(n_episodes=args.n_episodes, N=10, max_steps_per_episode=args.max_steps,
                      train_on_old_models=args.train_old_models, use_permutation=args.use_permutation,
                      start_learn_after=250)

    elif args.mode == 'eval':
        print(f"Starting model evaluation with {args.n_eval_episodes} episodes...")
        trainer.evaluateModel(n_eval_episodes=args.n_eval_episodes,
                              max_steps_per_episode=args.max_steps, plot=args.plot)

    elif args.mode == 'problem':
        print(f"Evaluating problem: {args.problem_path}")
        trainer.evaluateProblem(args.problem_path, max_steps=args.max_problem_steps, save_to_file=args.save_to_file)
88