Permute a high-level observation according to a given permutation.
This function reorders the rack-related part of the observation (ten racks of three values each, at indices 10 to 39) and copies the first ten entries unchanged. It is used for data augmentation during training to improve generalization.
import numpy as np


def permute_high_level_observation(permutation: np.ndarray, obs: np.ndarray) -> np.ndarray:
    """!
    @brief Permute a high-level observation according to a given permutation

    This function applies a permutation to the rack-related parts of the observation
    while keeping other parts unchanged. Used for data augmentation during training
    to improve generalization.

    @param permutation The permutation array to apply (size 10 for racks)
    @param obs The observation array to permute (size 40)
    @return The permuted observation array
    """

    permuted_obs = np.zeros(40)
    for i in range(10):
        # Offset of the source rack's three-value block within the rack section
        pos = permutation[i] * 3
        # The first ten entries are copied unchanged
        permuted_obs[i] = obs[i]
        # Rack block i of the result takes its three values from rack permutation[i]
        permuted_obs[10 + i*3] = obs[10 + pos]
        permuted_obs[11 + i*3] = obs[11 + pos]
        permuted_obs[12 + i*3] = obs[12 + pos]

    return permuted_obs
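
The same reordering can be expressed without an explicit loop. The following is a minimal sketch, not part of the original source: the name `permute_high_level_observation_vectorized` is hypothetical, and the layout (ten unchanged entries followed by ten racks of three values) is assumed from the loop above. It views the rack section as a 10x3 array and uses NumPy integer-array indexing to select rows in the permuted order.

```python
import numpy as np


def permute_high_level_observation_vectorized(permutation, obs):
    """Hypothetical loop-free equivalent; assumes a size-40 observation."""
    permutation = np.asarray(permutation)
    obs = np.asarray(obs, dtype=float)

    permuted_obs = np.empty(40)
    # The first ten entries are copied unchanged.
    permuted_obs[:10] = obs[:10]
    # View the rack section as 10 racks x 3 values, then pick rows so that
    # result rack i takes its values from source rack permutation[i].
    racks = obs[10:40].reshape(10, 3)
    permuted_obs[10:] = racks[permutation].reshape(-1)
    return permuted_obs


# Usage: shift every rack one slot to the right (the last rack wraps to slot 0).
obs = np.arange(40, dtype=float)
permutation = np.roll(np.arange(10), 1)  # [9, 0, 1, ..., 8]
result = permute_high_level_observation_vectorized(permutation, obs)
print(result[:10])    # unchanged prefix
print(result[10:16])  # first two rack blocks after permutation
```

With this input, rack slot 0 of the result holds the values of source rack 9, and slot 1 holds those of source rack 0. Like the loop version, the sketch always returns a float array, whatever the dtype of `obs`.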