diff --git a/mathesar/api/db/viewsets/schemas.py b/mathesar/api/db/viewsets/schemas.py index 320e3c4b80..fb7dc64f24 100644 --- a/mathesar/api/db/viewsets/schemas.py +++ b/mathesar/api/db/viewsets/schemas.py @@ -24,15 +24,18 @@ class SchemaViewSet(AccessViewSetMixin, viewsets.GenericViewSet, ListModelMixin, def get_queryset(self): qs = Schema.objects.all().order_by('-created_at') + connection_id = self.request.query_params.get('connection_id') + if connection_id: + qs = qs.filter(database=connection_id) return self.access_policy.scope_viewset_queryset(self.request, qs) def create(self, request): serializer = SchemaSerializer(data=request.data, context={'request': request}) serializer.is_valid(raise_exception=True) - database_name = serializer.validated_data['database'].name + connection_id = serializer.validated_data['database'].id schema = create_schema_and_object( serializer.validated_data['name'], - database_name, + connection_id, comment=serializer.validated_data.get('description') ) serializer = SchemaSerializer(schema, context={'request': request}) diff --git a/mathesar/api/serializers/schemas.py b/mathesar/api/serializers/schemas.py index 10c5de682a..f345362ed5 100644 --- a/mathesar/api/serializers/schemas.py +++ b/mathesar/api/serializers/schemas.py @@ -1,4 +1,4 @@ -from rest_access_policy import PermittedSlugRelatedField +from rest_access_policy import PermittedPkRelatedField from rest_framework import serializers from db.identifiers import is_identifier_too_long @@ -15,11 +15,10 @@ class SchemaSerializer(MathesarErrorMessageMixin, serializers.HyperlinkedModelSerializer): name = serializers.CharField() # Restrict access to databases with create access. - # Unlike PermittedPkRelatedField this field uses a slug instead of an id # Refer https://rsinger86.github.io/drf-access-policy/policy_reuse/ - database = PermittedSlugRelatedField( + connection_id = PermittedPkRelatedField( + source='database', access_policy=DatabaseAccessPolicy, - slug_field='name', queryset=Database.current_objects.all() ) description = serializers.CharField( @@ -31,7 +30,7 @@ class SchemaSerializer(MathesarErrorMessageMixin, serializers.HyperlinkedModelSe class Meta: model = Schema fields = [ - 'id', 'name', 'database', 'has_dependents', 'description', + 'id', 'name', 'connection_id', 'has_dependents', 'description', 'num_tables', 'num_queries' ] diff --git a/mathesar/tests/api/test_schema_api.py b/mathesar/tests/api/test_schema_api.py index 381f8c6bc0..82166a926b 100644 --- a/mathesar/tests/api/test_schema_api.py +++ b/mathesar/tests/api/test_schema_api.py @@ -14,12 +14,13 @@ def check_schema_response( schema, schema_name, test_db_name, + connection_id, schema_description=None, check_schema_objects=True ): assert response_schema['id'] == schema.id assert response_schema['name'] == schema_name - assert response_schema['database'] == test_db_name + assert response_schema['connection_id'] == connection_id assert response_schema['description'] == schema_description assert 'has_dependents' in response_schema if check_schema_objects: @@ -61,6 +62,7 @@ def test_schema_list(request, patent_schema, create_schema, MOD_engine_cache, cl patent_schema, patent_schema.name, patent_schema.database.name, + patent_schema.database.id ) @@ -119,7 +121,7 @@ def test_schema_list_filter(client, create_db_schema, FUN_create_dj_db, MOD_engi check_schema_response( MOD_engine_cache, response_schema, schema, schema.name, - schema.database.name, check_schema_objects=False + schema.database.name, schema.database.id, check_schema_objects=False ) @@ -169,7 +171,7 @@ def test_schema_detail(create_patents_table, client, test_db_name, MOD_engine_ca assert response.status_code == 200 check_schema_response( MOD_engine_cache, - response_schema, table.schema, table.schema.name, test_db_name + response_schema, table.schema, table.schema.name, test_db_name, table.schema.database.id ) @@ -204,7 +206,7 @@ def test_schema_sort_by_name(create_schema, client, MOD_engine_cache): for comparison_tuple in comparison_tuples: check_schema_response( MOD_engine_cache, comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name, - comparison_tuple[1].database.name + comparison_tuple[1].database.name, comparison_tuple[1].database.id ) sort_field = "name" response = client.get(f'/api/db/v0/schemas/?sort_by={sort_field}') @@ -214,7 +216,7 @@ def test_schema_sort_by_name(create_schema, client, MOD_engine_cache): for comparison_tuple in comparison_tuples: check_schema_response( MOD_engine_cache, comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name, - comparison_tuple[1].database.name + comparison_tuple[1].database.name, comparison_tuple[1].database.id ) @@ -250,7 +252,7 @@ def test_schema_sort_by_id(create_schema, client, MOD_engine_cache): check_schema_response( MOD_engine_cache, comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name, - comparison_tuple[1].database.name + comparison_tuple[1].database.name, comparison_tuple[1].database.id ) response = client.get('/api/db/v0/schemas/?sort_by=id') @@ -261,20 +263,20 @@ def test_schema_sort_by_id(create_schema, client, MOD_engine_cache): check_schema_response( MOD_engine_cache, comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name, - comparison_tuple[1].database.name + comparison_tuple[1].database.name, comparison_tuple[1].database.id ) def test_schema_create_by_superuser(client, FUN_create_dj_db, MOD_engine_cache): db_name = "some_db1" - FUN_create_dj_db(db_name) + database = FUN_create_dj_db(db_name) schema_count_before = Schema.objects.count() schema_name = 'Test Schema' data = { 'name': schema_name, - 'database': db_name + 'connection_id': database.id } response = client.post('/api/db/v0/schemas/', data=data) response_schema = response.json() @@ -289,19 +291,20 @@ def test_schema_create_by_superuser(client, FUN_create_dj_db, MOD_engine_cache): schema, schema_name, db_name, + database.id, check_schema_objects=0 ) def test_schema_create_by_superuser_too_long_name(client, FUN_create_dj_db): db_name = "some_db1" - FUN_create_dj_db(db_name) + database = FUN_create_dj_db(db_name) schema_count_before = Schema.objects.count() very_long_string = ''.join(map(str, range(50))) schema_name = 'very_long_identifier_' + very_long_string data = { 'name': schema_name, - 'database': db_name + 'connection_id': database.id } response = client.post('/api/db/v0/schemas/', data=data) assert response.status_code == 400 @@ -318,7 +321,7 @@ def test_schema_create_by_db_manager(client_bob, user_bob, FUN_create_dj_db, get schema_name = 'Test Schema' data = { 'name': schema_name, - 'database': db_name + 'connection_id': database.id } response = client_bob.post('/api/db/v0/schemas/', data=data) assert response.status_code == 400 @@ -337,7 +340,7 @@ def test_schema_create_by_db_editor(client_bob, user_bob, FUN_create_dj_db, get_ schema_name = 'Test Schema' data = { 'name': schema_name, - 'database': db_name + 'connection_id': database.id } response = client_bob.post('/api/db/v0/schemas/', data=data) assert response.status_code == 400 @@ -353,11 +356,11 @@ def test_schema_create_multiple_existing_roles(client_bob, user_bob, FUN_create_ schema_name = 'Test Schema' data = { 'name': schema_name, - 'database': database_with_viewer_access.name + 'connection_id': database_with_viewer_access.id } response = client_bob.post('/api/db/v0/schemas/', data=data) assert response.status_code == 400 - data['database'] = database_with_manager_access.name + data['connection_id'] = database_with_manager_access.id response = client_bob.post('/api/db/v0/schemas/', data=data) assert response.status_code == 201 @@ -365,7 +368,7 @@ def test_schema_create_multiple_existing_roles(client_bob, user_bob, FUN_create_ @pytest.mark.skip("Faulty DB handling assumptions; invalid") def test_schema_create_description(client, FUN_create_dj_db, MOD_engine_cache): db_name = "some_db2" - FUN_create_dj_db(db_name) + database = FUN_create_dj_db(db_name) schema_count_before = Schema.objects.count() @@ -373,7 +376,7 @@ def test_schema_create_description(client, FUN_create_dj_db, MOD_engine_cache): description = 'blah blah blah' data = { 'name': schema_name, - 'database': db_name, + 'connection_id': database.id, 'description': description, } response = client.post('/api/db/v0/schemas/', data=data) @@ -415,7 +418,14 @@ def test_schema_partial_update(create_schema, client, test_db_name, MOD_engine_c response_schema = response.json() assert response.status_code == 200 - check_schema_response(MOD_engine_cache, response_schema, schema, new_schema_name, test_db_name) + check_schema_response( + MOD_engine_cache, + response_schema, + schema, + new_schema_name, + test_db_name, + schema.database.id + ) schema = Schema.objects.get(oid=schema.oid) assert schema.name == new_schema_name diff --git a/mathesar/utils/schemas.py b/mathesar/utils/schemas.py index 7bd523add6..1d12c0060b 100644 --- a/mathesar/utils/schemas.py +++ b/mathesar/utils/schemas.py @@ -7,11 +7,12 @@ from mathesar.models.base import Schema, Database -def create_schema_and_object(name, database, comment=None): +def create_schema_and_object(name, connection_id, comment=None): try: - database_model = Database.objects.get(name=database) + database_model = Database.objects.get(id=connection_id) + database_name = database_model.name except ObjectDoesNotExist: - raise ValidationError({"database": f"Database '{database}' not found"}) + raise ValidationError({"database": f"Database '{database_name}' not found"}) engine = create_mathesar_engine(database_model)