Skip to content

Commit

Permalink
Merge pull request #121 from ntucllab/reward-to-recommendation
Browse files Browse the repository at this point in the history
move reward from History to Recommendation
  • Loading branch information
ianlini authored Nov 11, 2016
2 parents e94f022 + ac5a137 commit 7154629
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 18 deletions.
45 changes: 32 additions & 13 deletions striatum/bandit/tests/base_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@
)


def recommendations_to_rewards(recommendations):
    """Collect the rewards stored on recommendations into a dict.

    Parameters
    ----------
    recommendations : object or iterable of objects
        Either a single recommendation or several of them. Each one must
        expose ``action.id`` and ``reward``. A value without ``__iter__``
        is treated as one single recommendation.

    Returns
    -------
    rewards : dict
        Maps ``rec.action.id`` to ``rec.reward`` for every recommendation
        whose ``reward`` is not ``None``. Recommendations that have not
        been rewarded are left out, so the result is ``{}`` when none of
        them carries a reward.
    """
    # Normalise the single-recommendation case so one code path handles both.
    if not hasattr(recommendations, '__iter__'):
        recommendations = (recommendations,)
    return {rec.action.id: rec.reward
            for rec in recommendations
            if rec.reward is not None}



class BaseBanditTest(object):
# pylint: disable=protected-access

Expand Down Expand Up @@ -91,8 +103,9 @@ def test_update_reward(self):
history_id, recommendations = policy.get_action(context, 1)
rewards = {recommendations[0].action.id: 1.}
policy.reward(history_id, rewards)
self.assertEqual(
policy._history_storage.get_history(history_id).rewards, rewards)
history_rewards = recommendations_to_rewards(
policy._history_storage.get_history(history_id).recommendations)
self.assertEqual(history_rewards, rewards)

def test_delay_reward(self):
policy = self.policy
Expand All @@ -113,11 +126,13 @@ def test_delay_reward(self):
self.assertDictEqual(
policy._history_storage.get_unrewarded_history(history_id2).context,
context2)
self.assertDictEqual(
policy._history_storage.get_history(history_id1).rewards,
rewards)
self.assertIsNone(
policy._history_storage.get_unrewarded_history(history_id2).rewards)
history_rewards1 = recommendations_to_rewards(
policy._history_storage.get_history(history_id1).recommendations)
self.assertDictEqual(history_rewards1, rewards)
history_rewards2 = recommendations_to_rewards(
policy._history_storage.get_unrewarded_history(history_id2)
.recommendations)
self.assertDictEqual(history_rewards2, {})

def test_reward_order_descending(self):
policy = self.policy
Expand All @@ -132,10 +147,13 @@ def test_reward_order_descending(self):
context1)
self.assertDictEqual(
policy._history_storage.get_history(history_id2).context, context2)
self.assertIsNone(
policy._history_storage.get_unrewarded_history(history_id1).rewards)
self.assertDictEqual(
policy._history_storage.get_history(history_id2).rewards, rewards)
history_rewards1 = recommendations_to_rewards(
policy._history_storage.get_unrewarded_history(history_id1)
.recommendations)
self.assertDictEqual(history_rewards1, {})
history_rewards2 = recommendations_to_rewards(
policy._history_storage.get_history(history_id2).recommendations)
self.assertDictEqual(history_rewards2, rewards)

def test_update_action(self):
action = self.actions[1]
Expand Down Expand Up @@ -187,5 +205,6 @@ def test_remove_and_get_action_and_reward(self):

rewards = {recommendations[0].action.id: 1.}
policy.reward(history_id, rewards)
self.assertEqual(
policy._history_storage.get_history(history_id).rewards, rewards)
history_rewards = recommendations_to_rewards(
policy._history_storage.get_history(history_id).recommendations)
self.assertEqual(history_rewards, rewards)
2 changes: 1 addition & 1 deletion striatum/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
"""

from .model import MemoryModelStorage
from .history import MemoryHistoryStorage
from .history import MemoryHistoryStorage, History
from .action import MemoryActionStorage, Action
from .recommendation import Recommendation
14 changes: 11 additions & 3 deletions striatum/storage/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@ class History(object):
"""

def __init__(self, history_id, context, recommendations, created_at,
rewards=None, rewarded_at=None):
rewarded_at=None):
self.history_id = history_id
self.context = context
self.recommendations = recommendations
self.created_at = created_at
self.rewards = rewards
self.rewarded_at = rewarded_at

def update_reward(self, rewards, rewarded_at):
Expand All @@ -35,7 +34,16 @@ def update_reward(self, rewards, rewarded_at):
rewards : {float, dict of float, None}
rewarded_at : {datetime, None}
"""
self.rewards = rewards
if not hasattr(self.recommendations, '__iter__'):
recommendations = (self.recommendations,)
else:
recommendations = self.recommendations

for rec in recommendations:
try:
rec.reward = rewards[rec.action.id]
except KeyError:
pass
self.rewarded_at = rewarded_at


Expand Down
4 changes: 3 additions & 1 deletion striatum/storage/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ class Recommendation(object):
score: float
"""

def __init__(self, action, estimated_reward, uncertainty, score):
def __init__(self, action, estimated_reward, uncertainty, score,
reward=None):
self.action = action
self.estimated_reward = estimated_reward
self.uncertainty = uncertainty
self.score = score
self.reward = reward

0 comments on commit 7154629

Please sign in to comment.