diff --git a/datacontract/imports/avro_importer.py b/datacontract/imports/avro_importer.py index e219d00d..49bcf1a6 100644 --- a/datacontract/imports/avro_importer.py +++ b/datacontract/imports/avro_importer.py @@ -1,6 +1,7 @@ import avro.schema -from datacontract.model.data_contract_specification import DataContractSpecification, Model, Field +from datacontract.model.data_contract_specification import \ + DataContractSpecification, Model, Field from datacontract.model.exceptions import DataContractException @@ -56,6 +57,9 @@ def import_record_fields(record_fields): imported_fields[field.name].type = type if type == "record": imported_fields[field.name].fields = import_record_fields(get_record_from_union_field(field).fields) + elif type == "array": + imported_fields[field.name].type = "array" + imported_fields[field.name].items = import_avro_array_items(get_array_from_union_field(field)) elif field.type.type == "array": imported_fields[field.name].type = "array" imported_fields[field.name].items = import_avro_array_items(field.type) @@ -102,6 +106,13 @@ def get_record_from_union_field(field): return None +def get_array_from_union_field(field): + for field_type in field.type.schemas: + if field_type.type == "array": + return field_type + return None + + def map_type_from_avro(avro_type_str: str): # TODO: ambiguous mapping in the export if avro_type_str == "null": diff --git a/tests/fixtures/avro/data/arrays.avsc b/tests/fixtures/avro/data/arrays.avsc index 2f654711..35af71f0 100644 --- a/tests/fixtures/avro/data/arrays.avsc +++ b/tests/fixtures/avro/data/arrays.avsc @@ -39,6 +39,23 @@ "items": "int" } } + }, + { + "name": "nationalities", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "string", + "connect.parameters": { + "avro.java.string": "String" + }, + "avro.java.string": "String" + } + } + ], + "default": null } ], "name": "orders", diff --git a/tests/test_import_avro.py b/tests/test_import_avro.py index 1200374d..785034fb 100644 --- a/tests/test_import_avro.py +++ b/tests/test_import_avro.py @@ -122,6 +122,11 @@ def test_import_avro_arrays_of_records_and_nested_arrays(): type: array items: type: int + nationalities: + type: array + required: false + items: + type: string """ print("Result:\n", result.to_yaml()) assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) @@ -196,6 +201,8 @@ def test_import_avro_nested_records_with_arrays(): ApplicableBranchIDs: type: array required: false + items: + type: string ProductGroupDetails: type: record required: false @@ -206,6 +213,15 @@ def test_import_avro_nested_records_with_arrays(): ItemList: type: array required: false + items: + type: object + fields: + ProductID: + type: string + required: true + IsPromoItem: + type: boolean + required: false """ print("Result:\n", result.to_yaml()) assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected)