Skip to content

Commit

Permalink
Merge branch 'develop' into connections_rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
seancolsen authored Dec 12, 2023
2 parents 24c7604 + 5e3ed40 commit fa4bcb0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 28 deletions.
7 changes: 5 additions & 2 deletions mathesar/api/db/viewsets/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,18 @@ class SchemaViewSet(AccessViewSetMixin, viewsets.GenericViewSet, ListModelMixin,

def get_queryset(self):
qs = Schema.objects.all().order_by('-created_at')
connection_id = self.request.query_params.get('connection_id')
if connection_id:
qs = qs.filter(database=connection_id)
return self.access_policy.scope_viewset_queryset(self.request, qs)

def create(self, request):
serializer = SchemaSerializer(data=request.data, context={'request': request})
serializer.is_valid(raise_exception=True)
database_name = serializer.validated_data['database'].name
connection_id = serializer.validated_data['database'].id
schema = create_schema_and_object(
serializer.validated_data['name'],
database_name,
connection_id,
comment=serializer.validated_data.get('description')
)
serializer = SchemaSerializer(schema, context={'request': request})
Expand Down
9 changes: 4 additions & 5 deletions mathesar/api/serializers/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from rest_access_policy import PermittedSlugRelatedField
from rest_access_policy import PermittedPkRelatedField
from rest_framework import serializers

from db.identifiers import is_identifier_too_long
Expand All @@ -15,11 +15,10 @@
class SchemaSerializer(MathesarErrorMessageMixin, serializers.HyperlinkedModelSerializer):
name = serializers.CharField()
# Restrict access to databases with create access.
# Unlike PermittedPkRelatedField this field uses a slug instead of an id
# Refer https://rsinger86.github.io/drf-access-policy/policy_reuse/
database = PermittedSlugRelatedField(
connection_id = PermittedPkRelatedField(
source='database',
access_policy=DatabaseAccessPolicy,
slug_field='name',
queryset=Database.current_objects.all()
)
description = serializers.CharField(
Expand All @@ -31,7 +30,7 @@ class SchemaSerializer(MathesarErrorMessageMixin, serializers.HyperlinkedModelSe
class Meta:
model = Schema
fields = [
'id', 'name', 'database', 'has_dependents', 'description',
'id', 'name', 'connection_id', 'has_dependents', 'description',
'num_tables', 'num_queries'
]

Expand Down
46 changes: 28 additions & 18 deletions mathesar/tests/api/test_schema_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ def check_schema_response(
schema,
schema_name,
test_db_name,
connection_id,
schema_description=None,
check_schema_objects=True
):
assert response_schema['id'] == schema.id
assert response_schema['name'] == schema_name
assert response_schema['database'] == test_db_name
assert response_schema['connection_id'] == connection_id
assert response_schema['description'] == schema_description
assert 'has_dependents' in response_schema
if check_schema_objects:
Expand Down Expand Up @@ -61,6 +62,7 @@ def test_schema_list(request, patent_schema, create_schema, MOD_engine_cache, cl
patent_schema,
patent_schema.name,
patent_schema.database.name,
patent_schema.database.id
)


Expand Down Expand Up @@ -119,7 +121,7 @@ def test_schema_list_filter(client, create_db_schema, FUN_create_dj_db, MOD_engi
check_schema_response(
MOD_engine_cache,
response_schema, schema, schema.name,
schema.database.name, check_schema_objects=False
schema.database.name, schema.database.id, check_schema_objects=False
)


Expand Down Expand Up @@ -169,7 +171,7 @@ def test_schema_detail(create_patents_table, client, test_db_name, MOD_engine_ca
assert response.status_code == 200
check_schema_response(
MOD_engine_cache,
response_schema, table.schema, table.schema.name, test_db_name
response_schema, table.schema, table.schema.name, test_db_name, table.schema.database.id
)


Expand Down Expand Up @@ -204,7 +206,7 @@ def test_schema_sort_by_name(create_schema, client, MOD_engine_cache):
for comparison_tuple in comparison_tuples:
check_schema_response(
MOD_engine_cache, comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name,
comparison_tuple[1].database.name
comparison_tuple[1].database.name, comparison_tuple[1].database.id
)
sort_field = "name"
response = client.get(f'/api/db/v0/schemas/?sort_by={sort_field}')
Expand All @@ -214,7 +216,7 @@ def test_schema_sort_by_name(create_schema, client, MOD_engine_cache):
for comparison_tuple in comparison_tuples:
check_schema_response(
MOD_engine_cache, comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name,
comparison_tuple[1].database.name
comparison_tuple[1].database.name, comparison_tuple[1].database.id
)


Expand Down Expand Up @@ -250,7 +252,7 @@ def test_schema_sort_by_id(create_schema, client, MOD_engine_cache):
check_schema_response(
MOD_engine_cache,
comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name,
comparison_tuple[1].database.name
comparison_tuple[1].database.name, comparison_tuple[1].database.id
)

response = client.get('/api/db/v0/schemas/?sort_by=id')
Expand All @@ -261,20 +263,20 @@ def test_schema_sort_by_id(create_schema, client, MOD_engine_cache):
check_schema_response(
MOD_engine_cache,
comparison_tuple[0], comparison_tuple[1], comparison_tuple[1].name,
comparison_tuple[1].database.name
comparison_tuple[1].database.name, comparison_tuple[1].database.id
)


def test_schema_create_by_superuser(client, FUN_create_dj_db, MOD_engine_cache):
db_name = "some_db1"
FUN_create_dj_db(db_name)
database = FUN_create_dj_db(db_name)

schema_count_before = Schema.objects.count()

schema_name = 'Test Schema'
data = {
'name': schema_name,
'database': db_name
'connection_id': database.id
}
response = client.post('/api/db/v0/schemas/', data=data)
response_schema = response.json()
Expand All @@ -289,19 +291,20 @@ def test_schema_create_by_superuser(client, FUN_create_dj_db, MOD_engine_cache):
schema,
schema_name,
db_name,
database.id,
check_schema_objects=0
)


def test_schema_create_by_superuser_too_long_name(client, FUN_create_dj_db):
db_name = "some_db1"
FUN_create_dj_db(db_name)
database = FUN_create_dj_db(db_name)
schema_count_before = Schema.objects.count()
very_long_string = ''.join(map(str, range(50)))
schema_name = 'very_long_identifier_' + very_long_string
data = {
'name': schema_name,
'database': db_name
'connection_id': database.id
}
response = client.post('/api/db/v0/schemas/', data=data)
assert response.status_code == 400
Expand All @@ -318,7 +321,7 @@ def test_schema_create_by_db_manager(client_bob, user_bob, FUN_create_dj_db, get
schema_name = 'Test Schema'
data = {
'name': schema_name,
'database': db_name
'connection_id': database.id
}
response = client_bob.post('/api/db/v0/schemas/', data=data)
assert response.status_code == 400
Expand All @@ -337,7 +340,7 @@ def test_schema_create_by_db_editor(client_bob, user_bob, FUN_create_dj_db, get_
schema_name = 'Test Schema'
data = {
'name': schema_name,
'database': db_name
'connection_id': database.id
}
response = client_bob.post('/api/db/v0/schemas/', data=data)
assert response.status_code == 400
Expand All @@ -353,27 +356,27 @@ def test_schema_create_multiple_existing_roles(client_bob, user_bob, FUN_create_
schema_name = 'Test Schema'
data = {
'name': schema_name,
'database': database_with_viewer_access.name
'connection_id': database_with_viewer_access.id
}
response = client_bob.post('/api/db/v0/schemas/', data=data)
assert response.status_code == 400
data['database'] = database_with_manager_access.name
data['connection_id'] = database_with_manager_access.id
response = client_bob.post('/api/db/v0/schemas/', data=data)
assert response.status_code == 201


@pytest.mark.skip("Faulty DB handling assumptions; invalid")
def test_schema_create_description(client, FUN_create_dj_db, MOD_engine_cache):
db_name = "some_db2"
FUN_create_dj_db(db_name)
database = FUN_create_dj_db(db_name)

schema_count_before = Schema.objects.count()

schema_name = 'Test Schema with description'
description = 'blah blah blah'
data = {
'name': schema_name,
'database': db_name,
'connection_id': database.id,
'description': description,
}
response = client.post('/api/db/v0/schemas/', data=data)
Expand Down Expand Up @@ -415,7 +418,14 @@ def test_schema_partial_update(create_schema, client, test_db_name, MOD_engine_c

response_schema = response.json()
assert response.status_code == 200
check_schema_response(MOD_engine_cache, response_schema, schema, new_schema_name, test_db_name)
check_schema_response(
MOD_engine_cache,
response_schema,
schema,
new_schema_name,
test_db_name,
schema.database.id
)

schema = Schema.objects.get(oid=schema.oid)
assert schema.name == new_schema_name
Expand Down
7 changes: 4 additions & 3 deletions mathesar/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from mathesar.models.base import Schema, Database


def create_schema_and_object(name, database, comment=None):
def create_schema_and_object(name, connection_id, comment=None):
try:
database_model = Database.objects.get(name=database)
database_model = Database.objects.get(id=connection_id)
database_name = database_model.name
except ObjectDoesNotExist:
raise ValidationError({"database": f"Database '{database}' not found"})
raise ValidationError({"database": f"Database '{database_name}' not found"})

engine = create_mathesar_engine(database_model)

Expand Down

0 comments on commit fa4bcb0

Please sign in to comment.