diff --git a/src/lti/launch_params.py b/src/lti/launch_params.py index 46290c3..5e87f1a 100644 --- a/src/lti/launch_params.py +++ b/src/lti/launch_params.py @@ -160,7 +160,7 @@ def __init__(self, *args, **kwargs): # now verify we only got valid launch params for k in self.keys(): - if not valid_param(k): + if not self.valid_param(k): raise InvalidLaunchParamError(k) # enforce some defaults @@ -181,11 +181,14 @@ def _param_value(self, param): else: return self._params[param] + def valid_param(self, param): + return valid_param(param) + def __len__(self): return len(self._params) def __getitem__(self, item): - if not valid_param(item): + if not self.valid_param(item): raise KeyError("{} is not a valid launch param".format(item)) try: return self._param_value(item) @@ -194,7 +197,7 @@ def __getitem__(self, item): raise KeyError(item) def __setitem__(self, key, value): - if not valid_param(key): + if not self.valid_param(key): raise InvalidLaunchParamError(key) if key in LAUNCH_PARAMS_IS_LIST: if isinstance(value, list): diff --git a/src/lti/tool_provider.py b/src/lti/tool_provider.py index 8956293..4eaf480 100644 --- a/src/lti/tool_provider.py +++ b/src/lti/tool_provider.py @@ -21,11 +21,12 @@ class ToolProvider(ToolBase): ''' Implements the LTI Tool Provider. ''' + launch_params_class = LaunchParams @classmethod def from_unpacked_request(cls, secret, params, url, headers): - launch_params = LaunchParams(params) + launch_params = cls.launch_params_class(params) if 'oauth_consumer_key' not in launch_params: raise InvalidLTIRequestError("oauth_consumer_key not found!") diff --git a/tests/test_tool_provider.py b/tests/test_tool_provider.py index 3485999..d7678bb 100644 --- a/tests/test_tool_provider.py +++ b/tests/test_tool_provider.py @@ -19,13 +19,21 @@ def create_tp(key=None, secret=None, lp=None, launch_url=None, launch_headers=None, tp_class=ToolProvider): key = key or generate_client_id() secret = secret or generate_token() - launch_params = LaunchParams() + launch_params = tp_class.launch_params_class() if lp is not None: launch_params.update(lp) launch_url = launch_url or "http://example.edu" launch_headers = launch_headers or {} return tp_class(key, secret, launch_params, launch_url, launch_headers) +class CustomLaunchParams(LaunchParams): + def valid_param(self, param): + result = super(CustomLaunchParams, self).valid_param(param) + return result or param in ['basiclti_submit', 'launch_url'] + +class CustomToolProvider(ToolProvider): + launch_params_class = CustomLaunchParams + class TestToolProvider(unittest.TestCase): def test_constructor(self): @@ -215,6 +223,29 @@ def test_last_outcome_request(self): tp.outcome_requests = ['foo','bar'] self.assertEqual(tp.last_outcome_request(), 'bar') + def test_custom_launch_params(self): + key = generate_client_id() + secret = generate_token() + lp = { + 'lti_version': 'foo', + 'lti_message_type': 'bar', + 'resource_link_id': 123, + 'launch_url': 'more_foo', + 'basiclti_submit': 'more_bar' + } + launch_url = 'http://example.edu/foo/bar' + launch_headers = {'Content-Type': 'baz'} + tp = create_tp(key, secret, lp, launch_url, launch_headers, tp_class=CustomToolProvider) + + with patch.object(SignatureOnlyEndpoint, 'validate_request') as mv: + mv.return_value = True, None # Tuple of valid, request + self.assertTrue(tp.is_valid_request(Mock())) + call_url, call_method, call_params, call_headers = mv.call_args[0] + self.assertEqual(call_url, launch_url) + self.assertEqual(call_method, 'POST') + self.assertEqual(call_params, lp) + self.assertEqual(call_headers, launch_headers) + # mock the django.shortcuts import to allow testing mock = Mock() mock.shortcuts.redirect.return_value = 'foo'