# Source code for farmgym.v2.gymUnion
from gym.spaces.space import Space
import numpy as np
class Union(Space):
    """
    A union of simpler spaces, where each sample comes from a single sub-space.

    :meth:`sample` returns a pair ``(n, value)``: ``n`` is a sub-space index
    drawn uniformly at random and ``value`` is a sample from ``spaces[n]``.
    :meth:`contains` accepts ``x`` if any sub-space contains it.

    Example usage:
    self.action_space = Union([spaces.Discrete(2), spaces.Discrete(3)])
    """

    def __init__(self, spaces):
        """
        Build the union.

        :param spaces: sequence of ``gym.Space`` instances. It is iterated here
            and later indexed and measured with ``len``, so pass a list or
            tuple rather than a one-shot iterator.
        """
        self.spaces = spaces
        for space in spaces:
            assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
        super().__init__(None, None)

    def seed(self, seed=None):
        """
        Seed every sub-space with the same ``seed``.

        :return: list of the values returned by each sub-space's ``seed``
            (previously ``None``; gym's ``Space.seed`` contract returns seeds).

        NOTE(review): :meth:`sample` draws the sub-space index from the global
        ``np.random`` generator, which is not seeded here.
        """
        return [space.seed(seed) for space in self.spaces]

    def sample(self):
        """Return ``(n, value)`` with ``n`` a uniform random index and ``value`` sampled from ``spaces[n]``."""
        n = np.random.randint(len(self.spaces))
        return n, self.spaces[n].sample()

    def contains(self, x):
        """
        Return True if any sub-space contains ``x``.

        NOTE(review): ``x`` is tested as a raw element, whereas :meth:`sample`
        returns a tagged pair ``(n, value)``; confirm which form callers pass.
        """
        return any(space.contains(x) for space in self.spaces)

    def __repr__(self):
        return "Union(" + ", ".join(str(space) for space in self.spaces) + ")"

    def to_jsonable(self, sample_n):
        """Serialization is not implemented for this space; always returns an empty list."""
        return []

    def from_jsonable(self, sample_n):
        """Deserialization is not implemented for this space; always returns an empty list."""
        return []

    def __getitem__(self, index):
        return self.spaces[index]

    def __len__(self):
        return len(self.spaces)

    def __eq__(self, other):
        return isinstance(other, Union) and self.spaces == other.spaces
class MultiUnion(Space):
    """
    A space whose samples pick several distinct sub-spaces at once.

    :meth:`sample` returns a list of pairs ``(n, value)`` with distinct indices
    ``n``, each ``value`` sampled from ``spaces[n]``; the list holds between 0
    and ``min(maxnonzero, len(spaces))`` pairs.
    :meth:`contains` accepts a collection of at most ``maxnonzero`` elements,
    each of which must be contained in some sub-space.

    Example usage:
    self.action_space = MultiUnion([spaces.Discrete(2), spaces.Discrete(3)], maxnonzero=1)
    """

    def __init__(self, spaces, maxnonzero=np.inf):
        """
        Build the space.

        :param spaces: sequence of ``gym.Space`` instances. It is iterated here
            and later indexed and measured with ``len``, so pass a list or
            tuple rather than a one-shot iterator.
        :param maxnonzero: maximum number of elements in one point of the
            space; ``np.inf`` (the default) means unbounded. ``np.inf`` is used
            instead of ``np.infty``, which NumPy 2.0 removed.
        """
        self.spaces = spaces
        self.maxnonzero = maxnonzero
        for space in spaces:
            assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
        super().__init__(None, None)

    def seed(self, seed=None):
        """
        Seed every sub-space with the same ``seed``.

        :return: list of the values returned by each sub-space's ``seed``
            (previously ``None``; gym's ``Space.seed`` contract returns seeds).

        NOTE(review): :meth:`sample` draws the count and indices from the
        global ``np.random`` generator, which is not seeded here.
        """
        return [space.seed(seed) for space in self.spaces]

    def sample(self):
        """
        Return a list of ``(n, value)`` pairs for a random set of distinct sub-spaces.

        The number of pairs is uniform over ``0..min(maxnonzero, len(spaces))``
        inclusive, and the indices are drawn without replacement in random order.
        """
        n_spaces = len(self.spaces)
        # randint's upper bound is exclusive, hence the +1. The old bound
        # min(maxnonzero + 1, n_spaces) could never select all sub-spaces and
        # raised on an empty space list.
        max_count = max(0, int(min(self.maxnonzero, n_spaces)))
        count = np.random.randint(max_count + 1)
        chosen = np.random.choice(n_spaces, size=count, replace=False)
        return [(int(n), self.spaces[n].sample()) for n in chosen]

    def contains(self, x):
        """
        Return True if ``x`` has at most ``maxnonzero`` elements and every
        element is contained in at least one sub-space.

        Unsized ``x`` is rejected instead of raising ``TypeError``.

        NOTE(review): each element is tested as a whole against the sub-spaces,
        while :meth:`sample` yields ``(n, value)`` pairs; a pair is accepted only
        if some sub-space itself contains that pair. Confirm this is intended.
        """
        try:
            size = len(x)
        except TypeError:
            return False
        if size > self.maxnonzero:
            return False
        return all(self._some_space_contains(element) for element in x)

    def _some_space_contains(self, element):
        """Return True if any sub-space contains ``element``; a sub-space that raises counts as a miss."""
        for space in self.spaces:
            try:
                if space.contains(element):
                    return True
            except Exception:
                # Sub-spaces may raise on elements of an incompatible type or
                # shape; treat that as "not contained". Unlike the previous
                # bare except, this no longer swallows KeyboardInterrupt/SystemExit.
                continue
        return False

    def __repr__(self):
        prefix = "MultiUnion"
        if self.maxnonzero < np.inf:
            prefix += "[" + str(self.maxnonzero) + "]"
        return prefix + "(" + ", ".join(str(space) for space in self.spaces) + ")"

    def to_jsonable(self, sample_n):
        """Serialization is not implemented for this space; always returns an empty list."""
        return []

    def from_jsonable(self, sample_n):
        """Deserialization is not implemented for this space; always returns an empty list."""
        return []

    def __getitem__(self, index):
        return self.spaces[index]

    def __len__(self):
        return len(self.spaces)

    def __eq__(self, other):
        # maxnonzero is part of the space definition, so it takes part in equality.
        return (
            isinstance(other, MultiUnion)
            and self.maxnonzero == other.maxnonzero
            and self.spaces == other.spaces
        )