# online_nsrt_learning_approach.py
# Forked from Learning-and-Intelligent-Systems/predicators.
"""An NSRT learning approach that collects and learns from data online.
Example command:
python predicators/main.py --approach online_nsrt_learning --seed 0 \
--env cover \
--explorer glib \
--max_initial_demos 1 \
--num_train_tasks 1000 \
--num_test_tasks 10 \
--max_num_steps_interaction_request 10 \
--min_data_for_nsrt 10
"""
from __future__ import annotations
import itertools
import logging
from collections import defaultdict
from typing import DefaultDict, FrozenSet, List, Sequence, Set
from gym.spaces import Box
from predicators import utils
from predicators.approaches.nsrt_learning_approach import NSRTLearningApproach
from predicators.explorers import create_explorer
from predicators.settings import CFG
from predicators.structs import Dataset, GroundAtom, InteractionRequest, \
InteractionResult, LiftedAtom, LowLevelTrajectory, Object, \
ParameterizedOption, Predicate, Task, Type
class OnlineNSRTLearningApproach(NSRTLearningApproach):
    """NSRT learning approach that gathers data online and re-learns from it.

    Every trajectory handed to this approach (from the offline dataset and
    from online interaction results) is appended to one cumulative dataset.
    After each batch of interaction results the NSRTs are re-learned from
    that full dataset. Exploration is delegated to the explorer named by
    ``CFG.explorer``, which receives a count-based novelty score over lifted
    atom sets (used by GLIB).
    """

    def __init__(self, initial_predicates: Set[Predicate],
                 initial_options: Set[ParameterizedOption], types: Set[Type],
                 action_space: Box, train_tasks: List[Task]) -> None:
        """Initialize the parent approach and the online-learning state."""
        super().__init__(initial_predicates, initial_options, types,
                         action_space, train_tasks)
        # Cumulative dataset of every trajectory seen so far.
        self._dataset = Dataset([])
        # Cycle index that is passed along to _learn_nsrts.
        self._online_learning_cycle = 0
        # Used for the novelty score function (in GLIB): maps a canonical
        # lifted atom set to the number of times it was seen in the data.
        self._novelty_counts: DefaultDict[FrozenSet[LiftedAtom], int] = \
            defaultdict(int)

    @classmethod
    def get_name(cls) -> str:
        """Return the name under which this approach is selected."""
        return "online_nsrt_learning"

    def get_interaction_requests(self) -> List[InteractionRequest]:
        """Create the exploration requests for one online learning cycle.

        Returns:
            A list with ``CFG.online_nsrt_learning_requests_per_cycle``
            requests, each targeting a train task index drawn at random
            (with replacement).
        """
        # Explore in the train tasks. The number of train tasks that are
        # explored at each timestep is a hyperparameter. The train task
        # is randomly selected.

        # Create the explorer. Note that greedy lookahead is not yet supported.
        preds = self._get_current_predicates()
        explorer = create_explorer(CFG.explorer,
                                   preds,
                                   self._initial_options,
                                   self._types,
                                   self._action_space,
                                   self._train_tasks,
                                   self._get_current_nsrts(),
                                   self._option_model,
                                   babble_predicates=preds,
                                   atom_score_fn=self._score_atoms_novelty)

        # Create the interaction requests.
        requests = []
        for _ in range(CFG.online_nsrt_learning_requests_per_cycle):
            # Select a random task (with replacement).
            task_idx = self._rng.choice(len(self._train_tasks))
            if CFG.env == "behavior":
                # For this env the explorer provides an action plan rather
                # than a policy / termination function pair.
                act_plan = explorer.get_exploration_plan_strategy(
                    task_idx, CFG.timeout)
                # Create the interaction request.
                req = InteractionRequest(train_task_idx=task_idx,
                                         act_policy=None,
                                         query_policy=lambda s: None,
                                         termination_function=None,
                                         act_plan=act_plan)
            else:
                # Set up the explorer policy and termination function.
                policy, termination_function = \
                    explorer.get_exploration_strategy(task_idx, CFG.timeout)
                # Create the interaction request. The query policy always
                # returns None.
                req = InteractionRequest(
                    train_task_idx=task_idx,
                    act_policy=policy,
                    query_policy=lambda s: None,
                    termination_function=termination_function)
            requests.append(req)
        return requests

    def learn_from_offline_dataset(self, dataset: Dataset) -> None:
        """Record the offline trajectories, then defer to the parent class."""
        # Update the dataset with the offline data.
        for traj in dataset.trajectories:
            self._update_dataset(traj)
        super().learn_from_offline_dataset(dataset)

    def learn_from_interaction_results(
            self, results: Sequence[InteractionResult]) -> None:
        """Add the new interaction data and re-learn NSRTs from all data."""
        # Add the new data to the cumulative dataset.
        for result in results:
            traj = LowLevelTrajectory(result.states, result.actions)
            self._update_dataset(traj)
        # Re-learn the NSRTs.
        self._learn_nsrts(self._dataset.trajectories,
                          self._online_learning_cycle)
        # Advance the online learning cycle.
        self._online_learning_cycle += 1

    def _update_dataset(self, trajectory: LowLevelTrajectory) -> None:
        """Add a new low-level trajectory to the dataset.

        Also increments the novelty count of every lifted atom subset of
        size ``CFG.glib_min_goal_size`` .. ``CFG.glib_max_goal_size``
        (inclusive) that holds in each state of the trajectory.
        """
        # Update the dataset.
        self._dataset.append(trajectory)
        # Update the atom counts for the novelty score function (for GLIB).
        # NOTE: this assumes that predicates are not changing because we are
        # just running the predicate classifiers once per state for efficiency.
        assert not CFG.excluded_predicates  # make sure not predicate learning
        preds = self._get_current_predicates()
        assert preds == self._initial_predicates
        goal_sizes = range(CFG.glib_min_goal_size, CFG.glib_max_goal_size + 1)
        for state in trajectory.states:
            atoms = utils.abstract(state, preds)
            for size in goal_sizes:
                for atom_tup in itertools.combinations(atoms, size):
                    can_atom_set = self._get_canonical_lifted_atoms(
                        set(atom_tup))
                    self._novelty_counts[can_atom_set] += 1
        # Lazy %-style args: the (possibly large) table is only rendered to
        # text when DEBUG logging is actually enabled.
        logging.debug("Novelty counts: %s", self._novelty_counts)

    def _score_atoms_novelty(self, atoms: Set[GroundAtom]) -> float:
        """Score the novelty of a ground atom set, with higher better.

        Score based on the number of times that this atom set has been seen in
        the data, with object identities ignored (i.e., this is lifted).

        Assumes that the size of the atom set is between CFG.glib_min_goal_size
        and CFG.glib_max_goal_size (inclusive).

        Returns:
            The negated count, or negative infinity once the count exceeds
            ``CFG.online_learning_max_novelty_count``.
        """
        assert CFG.glib_min_goal_size <= len(atoms) <= CFG.glib_max_goal_size
        canonical_atoms = self._get_canonical_lifted_atoms(atoms)
        # Use .get() rather than indexing: indexing a defaultdict would insert
        # a zero entry for every atom set that is merely scored, growing the
        # table with sets that were never observed.
        count = self._novelty_counts.get(canonical_atoms, 0)
        # Once some goal has been seen more than
        # online_learning_max_novelty_count times, it is no longer considered
        # "novel" and, for example, won't be babbled by the GLIB explorer
        # anymore.
        if count > CFG.online_learning_max_novelty_count:
            return -float("inf")
        # Note minus sign: less frequent is better.
        return -count

    @staticmethod
    def _get_canonical_lifted_atoms(
            atoms: Set[GroundAtom]) -> FrozenSet[LiftedAtom]:
        """Create a canonical lifted atoms set.

        This is a helper for novelty scoring for GLIB.

        This is an efficient approximation of what we really care about, which
        is whether two atom sets unify. It's an approximation because there are
        tricky cases where the sorting procedure is ambiguous.
        """
        # Create a "signature" for each object, which will be used to break
        # ties when sorting based on predicates alone is not enough. The
        # signature is the sorted tuple of predicates of the atoms in which
        # the object appears.
        objs = {o for a in atoms for o in a.objects}
        obj_to_sig = {
            o: tuple(sorted(a.predicate for a in atoms if o in a.objects))
            for o in objs
        }

        def _sort_key(atom: GroundAtom) -> tuple:
            # Sort based first on predicate, then on object signatures.
            return (atom.predicate,
                    tuple(obj_to_sig[o] for o in atom.objects))

        sorted_atom_set = sorted(atoms, key=_sort_key)
        # Collect the objects in order of first appearance; dict.fromkeys
        # de-duplicates while preserving that order.
        sorted_objs: List[Object] = list(
            dict.fromkeys(o for a in sorted_atom_set for o in a.objects))
        # Replace the objects with variables in order.
        variables = utils.create_new_variables([o.type for o in sorted_objs])
        sub = dict(zip(sorted_objs, variables))
        # Lift the atoms.
        return frozenset(a.lift(sub) for a in sorted_atom_set)