Skip to content

Commit

Permalink
Ne plus utiliser les RNN (#29)
Browse files Browse the repository at this point in the history
* fix: effacer tout ce qui est RNN, ça marche pas, bon débarras

* feat: restaurer la tokenisation car c'est utile
  • Loading branch information
dhdaines authored Jul 25, 2024
1 parent 8b97dca commit cbb0c0e
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 3,576 deletions.
9 changes: 3 additions & 6 deletions alexi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .label import Identificateur
from .search import search
from .segment import DEFAULT_MODEL as DEFAULT_SEGMENT_MODEL
from .segment import RNNSegmenteur, Segmenteur
from .segment import Segmenteur

LOGGER = logging.getLogger("alexi")
VERSION = "0.4.0"
Expand Down Expand Up @@ -60,10 +60,7 @@ def convert_main(args: argparse.Namespace):
def segment_main(args: argparse.Namespace):
"""Segmenter un CSV"""
crf: Segmenteur
if args.model.suffix == ".pt":
crf = RNNSegmenteur(args.model)
else:
crf = Segmenteur(args.model)
crf = Segmenteur(args.model)
reader = csv.DictReader(args.csv)
write_csv(crf(reader), sys.stdout)

Expand Down Expand Up @@ -144,7 +141,7 @@ def make_argparse() -> argparse.ArgumentParser:
"segment", help="Segmenter et étiquetter les segments d'un CSV"
)
segment.add_argument(
"--model", help="Modele CRF ou RNN", type=Path, default=DEFAULT_SEGMENT_MODEL
"--model", help="Modele CRF", type=Path, default=DEFAULT_SEGMENT_MODEL
)
segment.add_argument(
"csv",
Expand Down
9 changes: 3 additions & 6 deletions alexi/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from alexi.label import Identificateur
from alexi.link import Resolver
from alexi.segment import DEFAULT_MODEL as DEFAULT_SEGMENT_MODEL
from alexi.segment import DEFAULT_MODEL_NOSTRUCT, RNNSegmenteur, Segmenteur
from alexi.segment import DEFAULT_MODEL_NOSTRUCT, Segmenteur
from alexi.types import T_obj

LOGGER = logging.getLogger("extract")
Expand All @@ -39,7 +39,7 @@ def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
help="Ne pas utiliser le CSV de référence",
action="store_true",
)
parser.add_argument("--segment-model", help="Modele CRF/RNN", type=Path)
parser.add_argument("--segment-model", help="Modele CRF", type=Path)
parser.add_argument(
"--label-model", help="Modele CRF", type=Path, default=DEFAULT_LABEL_MODEL
)
Expand Down Expand Up @@ -342,10 +342,7 @@ def __init__(
self.outdir = outdir
self.crf_s = Identificateur()
if segment_model is not None:
if segment_model.suffix == ".pt":
self.crf = RNNSegmenteur(segment_model)
else:
self.crf = Segmenteur(segment_model)
self.crf = Segmenteur(segment_model)
self.crf_n = None
else:
self.crf = Segmenteur(DEFAULT_SEGMENT_MODEL)
Expand Down
Loading

0 comments on commit cbb0c0e

Please sign in to comment.