Skip to content

Commit

Permalink
Add typing for all lists
Browse files Browse the repository at this point in the history
  • Loading branch information
daveisfera committed Nov 8, 2024
1 parent 7a4ac7f commit 2e3650a
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions m3u8/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# license that can be found in the LICENSE file.
import decimal
import os
from typing import TypeVar

from m3u8.mixins import BasePathMixin, GroupedBasePathMixin
from m3u8.parser import format_date_time, parse
Expand Down Expand Up @@ -449,7 +450,10 @@ def _create_sub_directories(self, filename):
os.makedirs(basename, exist_ok=True)


class TagList(list):
T = TypeVar("T")


class TagList(list[T]):
def __str__(self):
output = [str(tag) for tag in self]
return "\n".join(output)
Expand Down Expand Up @@ -712,7 +716,7 @@ def base_uri(self, newbase_uri):
self.init_section.base_uri = newbase_uri


class SegmentList(list, GroupedBasePathMixin):
class SegmentList(list[Segment], GroupedBasePathMixin):
def dumps(self, timespec="milliseconds", infspec="auto"):
output = []
last_segment = None
Expand Down Expand Up @@ -827,7 +831,7 @@ def __str__(self):
return self.dumps(None)


class PartialSegmentList(list, GroupedBasePathMixin):
class PartialSegmentList(list[PartialSegment], GroupedBasePathMixin):
def __str__(self):
output = [str(part) for part in self]
return "\n".join(output)
Expand Down Expand Up @@ -1015,7 +1019,7 @@ def __str__(self):
return "#EXT-X-STREAM-INF:" + ",".join(stream_inf) + "\n" + self.uri


class PlaylistList(TagList, GroupedBasePathMixin):
class PlaylistList(TagList[Playlist], GroupedBasePathMixin):
pass


Expand Down Expand Up @@ -1270,7 +1274,7 @@ def __str__(self):
return self.dumps()


class MediaList(TagList, GroupedBasePathMixin):
class MediaList(TagList[Media], GroupedBasePathMixin):
@property
def uri(self):
return [media.uri for media in self]
Expand Down Expand Up @@ -1310,7 +1314,7 @@ def __str__(self):
return self.dumps()


class RenditionReportList(list, GroupedBasePathMixin):
class RenditionReportList(list[RenditionReport], GroupedBasePathMixin):
def __str__(self):
output = [str(report) for report in self]
return "\n".join(output)
Expand Down Expand Up @@ -1439,7 +1443,7 @@ def __str__(self):
return self.dumps()


class SessionDataList(TagList):
class SessionDataList(TagList[SessionData]):
pass


Expand Down Expand Up @@ -1498,7 +1502,7 @@ def __str__(self):
return self.dumps()


class DateRangeList(TagList):
class DateRangeList(TagList[DateRange]):
pass


Expand Down

0 comments on commit 2e3650a

Please sign in to comment.