Skip to content

Commit

Permalink
tests: fix and optimize tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CyanVoxel committed Nov 27, 2024
1 parent 5d60797 commit 001019f
Showing 1 changed file with 97 additions and 91 deletions.
188 changes: 97 additions & 91 deletions tagstudio/src/qt/widgets/migration_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,19 @@ def sanitize_json_field(value):
for x in json_entry.fields
]
json_fields.sort()
if not (
(json_fields == sql_fields)
and json_fields is not None
and sql_fields is not None
):
return False

logger.info(
"[Field Comparison]",
fields="\n".join([str(x) for x in zip(json_fields, sql_fields)]),
)

logger.info("Field Comparison:")
logger.info("\n".join([str(x) for x in zip(json_fields, sql_fields)]))

return (json_fields == sql_fields) and json_fields is not None and sql_fields is not None
return True

def check_path_parity(self) -> bool:
"""Check if all JSON file paths match the new SQL paths."""
Expand All @@ -568,117 +576,115 @@ def check_path_parity(self) -> bool:

def check_subtag_parity(self) -> bool:
"""Check if all JSON subtags match the new SQL subtags."""
sql_subtags: set[int] = None
json_subtags: set[int] = None

with Session(self.sql_lib.engine) as session:
sql_subtags: list[tuple[int, set[int]]] = []
json_subtags: list[tuple[int, set[int]]] = []

for sql_tag in self.sql_lib.tags:
subtags = (
sql_tag.id,
set(
session.scalars(
select(TagSubtag.child_id).where(TagSubtag.parent_id == sql_tag.id)
)
),
for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_subtags = set(
session.scalars(select(TagSubtag.child_id).where(TagSubtag.parent_id == tag.id))
)
# JSON tags allowed self-parenting; SQL tags no longer allow this.
json_subtags = set(self.json_lib.get_tag(tag_id).subtag_ids).difference(
set([self.json_lib.get_tag(tag_id).id])
)
sql_subtags.append(subtags)
logger.info(sql_subtags[-1])
sql_subtags.sort()

for json_tag in self.json_lib.tags:
json_subtags.append(
(json_tag.id, set(json_tag.subtag_ids).difference(set([json_tag.id])))
logger.info(
"[Subtag Parity]",
tag_id=tag_id,
json_subtags=json_subtags,
sql_subtags=sql_subtags,
)
logger.info(json_subtags[-1])
json_subtags.sort()

logger.info("Subtag Comparison:")
logger.info("\n".join([str(x) for x in zip(json_subtags, sql_subtags)]))
if not (
(sql_subtags == json_subtags)
and sql_subtags is not None
and json_subtags is not None
):
return False

return (
(sql_subtags == json_subtags) and sql_subtags is not None and json_subtags is not None
)
return True

def check_ext_type(self) -> bool:
return self.json_lib.is_exclude_list == self.sql_lib.prefs(LibraryPrefs.IS_EXCLUDE_LIST)

def check_alias_parity(self) -> bool:
"""Check if all JSON aliases match the new SQL aliases."""
sql_aliases: set[str] = None
json_aliases: set[str] = None

with Session(self.sql_lib.engine) as session:
sql_aliases: list[tuple[int, set[str]]] = []
json_aliases: list[tuple[int, set[str]]] = []

for sql_tag in self.sql_lib.tags:
aliases = (
sql_tag.id,
set(
session.scalars(select(TagAlias.name).where(TagAlias.tag_id == sql_tag.id))
),
for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_aliases = set(
session.scalars(select(TagAlias.name).where(TagAlias.tag_id == tag.id))
)
sql_aliases.append(aliases)
sql_aliases.sort()
json_aliases = set(self.json_lib.get_tag(tag_id).aliases)

for json_tag in self.json_lib.tags:
json_aliases.append((json_tag.id, set(json_tag.aliases)))
json_aliases.sort()

logger.info("Alias Comparison:")
logger.info("\n".join([str(x) for x in zip(json_aliases, sql_aliases)]))
logger.info(
"[Alias Parity]",
tag_id=tag_id,
json_aliases=json_aliases,
sql_aliases=sql_aliases,
)
if not (
(sql_aliases == json_aliases)
and sql_aliases is not None
and json_aliases is not None
):
return False

return (
(sql_aliases == json_aliases) and sql_aliases is not None and json_aliases is not None
)
return True

def check_shorthand_parity(self) -> bool:
"""Check if all JSON shorthands match the new SQL shorthands."""
with Session(self.sql_lib.engine) as session:
sql_shorthands: list[tuple[int, set[str]]] = []
json_shorthands: list[tuple[int, set[str]]] = []

for sql_tag in self.sql_lib.tags:
shorthands = (
sql_tag.id,
set(
session.scalars(select(TagAlias.name).where(TagAlias.tag_id == sql_tag.id))
),
)
sql_shorthands.append(shorthands)
sql_shorthands.sort()
sql_shorthand: str = None
json_shorthand: str = None

for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_shorthand = tag.shorthand
json_shorthand = self.json_lib.get_tag(tag_id).shorthand

logger.info(
"[Shorthand Parity]",
tag_id=tag_id,
json_shorthand=json_shorthand,
sql_shorthand=sql_shorthand,
)

for json_tag in self.json_lib.tags:
json_shorthands.append((json_tag.id, set(json_tag.aliases)))
json_shorthands.sort()
if not (
(sql_shorthand == json_shorthand)
and sql_shorthand is not None
and json_shorthand is not None
):
return False

logger.info("Shorthand Comparison:")
logger.info("\n".join([str(x) for x in zip(json_shorthands, sql_shorthands)]))

return (
(sql_shorthands == json_shorthands)
and sql_shorthands is not None
and json_shorthands is not None
)
return True

def check_color_parity(self) -> bool:
"""Check if all JSON tag colors match the new SQL tag colors."""
sql_colors: list[tuple[int, str]] = []
json_colors: list[tuple[int, str]] = []

for sql_tag in self.sql_lib.tags:
sql_colors.append((sql_tag.id, (sql_tag.color.name)))
sql_colors.sort()

for json_tag in self.json_lib.tags:
json_colors.append(
(
json_tag.id,
json_tag.color.upper().replace(" ", "_")
if json_tag.color != ""
else TagColor.DEFAULT.name,
)
sql_color: str = None
json_color: str = None

for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_color = tag.color.name
json_color = (
self.json_lib.get_tag(tag_id).color.upper().replace(" ", "_")
if self.json_lib.get_tag(tag_id).color != ""
else TagColor.DEFAULT.name
)

logger.info(
"[Color Parity]",
tag_id=tag_id,
json_shorthand=json_color,
sql_shorthand=sql_color,
)
json_colors.sort()

logger.info("Color Comparison:")
logger.info("\n".join([str(x) for x in zip(json_colors, sql_colors)]))
if not ((sql_color == json_color) and sql_color is not None and json_color is not None):
return False

return (sql_colors == json_colors) and sql_colors is not None and json_colors is not None
return True

0 comments on commit 001019f

Please sign in to comment.