Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add invert to FFI and Python #277

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions rustfst-ffi/src/algorithms/inversion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use anyhow::anyhow;

use crate::fst::CFst;
use crate::{get_mut, wrap, RUSTFST_FFI_RESULT};

use rustfst::algorithms::invert;
use rustfst::fst_impls::VectorFst;

/// # Safety
///
/// The pointers should be valid.
#[no_mangle]
pub unsafe extern "C" fn fst_invert(ptr: *mut CFst) -> RUSTFST_FFI_RESULT {
wrap(|| {
let fst = get_mut!(CFst, ptr);
let vec_fst: &mut VectorFst<_> = fst
.downcast_mut()
.ok_or_else(|| anyhow!("Could not downcast to vector FST"))?;
invert(vec_fst);
Ok(())
})
}
1 change: 1 addition & 0 deletions rustfst-ffi/src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod compose;
pub mod concat;
pub mod connect;
pub mod determinize;
pub mod inversion;
pub mod isomorphic;
mod minimize;
pub mod optimize;
Expand Down
1 change: 1 addition & 0 deletions rustfst-python/docs/rustfst/algorithms/inversion/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: rustfst.algorithms.inversion
2 changes: 2 additions & 0 deletions rustfst-python/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ nav:
- rustfst/algorithms/rm_epsilon/index.md
- reverse:
- rustfst/algorithms/reverse/index.md
- inversion:
- rustfst/algorithms/inversion/index.md
- project:
- rustfst/algorithms/project/index.md
- randgen:
Expand Down
25 changes: 25 additions & 0 deletions rustfst-python/rustfst/algorithms/inversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations
from rustfst.ffi_utils import (
lib,
check_ffi_error,
)

from rustfst.fst.vector_fst import VectorFst


def invert(fst: VectorFst) -> VectorFst:
"""
Invert the transduction corresponding to an FST by exchanging the
FST's input and output labels in-place.

Args:
fst: FST to be inverted.
Returns:
self
"""

ret_code = lib.fst_invert(fst.ptr)
err_msg = "Error during invert"
check_ffi_error(ret_code, err_msg)

return fst
14 changes: 9 additions & 5 deletions rustfst-python/rustfst/algorithms/tr_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@


def tr_sort(fst: VectorFst, ilabel_cmp: bool):
"""
tr_sort(fst)
sort fst trs according to their ilabel or olabel
:param fst: Fst
:param ilabel_cmp: bool
"""Sort fst trs in place according to their input or output label.

This is often necessary for composition to work properly. It
corresponds to `ArcSort` in OpenFST.

Args:
fst: FST to be tr-sorted
ilabel_cmp: Sort on input labels if `True`, output labels
if `False`.
"""

ret_code = lib.fst_tr_sort(fst.ptr, ctypes.c_bool(ilabel_cmp))
Expand Down
12 changes: 12 additions & 0 deletions rustfst-python/rustfst/fst/vector_fst.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,18 @@ def isomorphic(self, other: VectorFst) -> bool:

return isomorphic(self, other)

def invert(self) -> VectorFst:
"""
Invert the transduction corresponding to an FST by exchanging the
FST's input and output labels in-place.

Returns:
self
"""
from rustfst.algorithms.inversion import invert

return invert(self)

def __add__(self, other: VectorFst) -> VectorFst:
"""
`fst_1 + fst_2` is a shortcut to perform the concatenation of `fst_1` and `fst_2`.
Expand Down
51 changes: 51 additions & 0 deletions rustfst-python/tests/algorithms/test_inversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from rustfst import VectorFst, Tr


def test_invert():
# FST 1
fst1 = VectorFst()

s1 = fst1.add_state()
s2 = fst1.add_state()
s3 = fst1.add_state()

fst1.set_start(s1)
fst1.set_final(s3, 1.0)

tr1_1 = Tr(1, 2, 1.0, s2)
fst1.add_tr(s1, tr1_1)

tr1_2 = Tr(3, 4, 2.0, s2)
fst1.add_tr(s1, tr1_2)

tr1_3 = Tr(5, 6, 1.5, s2)
fst1.add_tr(s2, tr1_3)

tr1_4 = Tr(3, 5, 1.0, s3)
fst1.add_tr(s2, tr1_4)

fst1 = fst1.invert()

# Expected FST
expected_fst = VectorFst()

s1 = expected_fst.add_state()
s2 = expected_fst.add_state()
s3 = expected_fst.add_state()

expected_fst.set_start(s1)
expected_fst.set_final(s3, 1.0)

tr1_1 = Tr(2, 1, 1.0, s2)
expected_fst.add_tr(s1, tr1_1)

tr1_2 = Tr(4, 3, 2.0, s2)
expected_fst.add_tr(s1, tr1_2)

tr1_3 = Tr(6, 5, 1.5, s2)
expected_fst.add_tr(s2, tr1_3)

tr1_4 = Tr(5, 3, 1.0, s3)
expected_fst.add_tr(s2, tr1_4)

assert expected_fst == fst1