diff --git a/alphabet/alphabet.go b/alphabet/alphabet.go index 7355a83f..6eab6435 100644 --- a/alphabet/alphabet.go +++ b/alphabet/alphabet.go @@ -8,7 +8,7 @@ import "fmt" // Alphabet is a struct that holds a list of symbols and a map of symbols to their index in the list. type Alphabet struct { symbols []string - encoding map[interface{}]int + encoding map[interface{}]uint8 } // Error is an error type that is returned when a symbol is not in the alphabet. @@ -23,23 +23,37 @@ func (e *Error) Error() string { // NewAlphabet creates a new alphabet from a list of symbols. func NewAlphabet(symbols []string) *Alphabet { - encoding := make(map[interface{}]int) + encoding := make(map[interface{}]uint8) for index, symbol := range symbols { - encoding[symbol] = index - encoding[index] = index + encoding[symbol] = uint8(index) + encoding[index] = uint8(index) } return &Alphabet{symbols, encoding} } // Encode returns the index of a symbol in the alphabet. -func (alphabet *Alphabet) Encode(symbol interface{}) (int, error) { +func (alphabet *Alphabet) Encode(symbol interface{}) (uint8, error) { c, ok := alphabet.encoding[symbol] if !ok { - return 0, &Error{fmt.Sprintf("Symbol %v not in alphabet", symbol)} + return 0, fmt.Errorf("Symbol %v not in alphabet", symbol) } return c, nil } +// TODO: compress more when len(symbols) << 2^8 +// TODO: DecodeAll +func (alphabet *Alphabet) EncodeAll(seq string) ([]uint8, error) { + encoded := make([]uint8, len(seq)) + for i, r := range seq { + encoding, err := alphabet.Encode(string(r)) + if err != nil { + return nil, fmt.Errorf("Symbol %c in position %d not in alphabet", r, i) + } + encoded[i] = uint8(encoding) + } + return encoded, nil +} + // Decode returns the symbol at a given index in the alphabet. func (alphabet *Alphabet) Decode(code interface{}) (string, error) { c, ok := code.(int) diff --git a/alphabet/alphabet_test.go b/alphabet/alphabet_test.go index 81dd70ad..dc8ea04f 100644 --- a/alphabet/alphabet_test.go +++ b/alphabet/alphabet_test.go @@ -16,7 +16,7 @@ func TestAlphabet(t *testing.T) { if err != nil { t.Errorf("Unexpected error encoding symbol %s: %v", symbol, err) } - if code != i { + if int(code) != i { t.Errorf("Incorrect encoding of symbol %s: expected %d, got %d", symbol, i, code) } } @@ -48,7 +48,7 @@ func TestAlphabet(t *testing.T) { if err != nil { t.Errorf("Unexpected error encoding symbol %s: %v", symbol, err) } - if code != i { + if int(code) != i { t.Errorf("Incorrect encoding of symbol %s: expected %d, got %d", symbol, i, code) } } @@ -57,7 +57,7 @@ func TestAlphabet(t *testing.T) { if err != nil { t.Errorf("Unexpected error encoding symbol %s: %v", symbol, err) } - if code != i+len(symbols) { + if int(code) != i+len(symbols) { t.Errorf("Incorrect encoding of symbol %s: expected %d, got %d", symbol, i+len(symbols), code) } } diff --git a/alphabet/kmer_counter.go b/alphabet/kmer_counter.go new file mode 100644 index 00000000..e1dc4208 --- /dev/null +++ b/alphabet/kmer_counter.go @@ -0,0 +1,159 @@ +// TODO: add Alphabet and KmerCounter for codons +// TODO: create benchmark for space and time costs +// TODO: create efficient storage for children + counts in top node via multidimensional array +// TODO: integrate with IO parsers +// TODO: enable reading and writing to binary format +// TODO: iterate over observed kmers (what is the appropriate iteration method?) +// TODO: comparison between two KmerCounter's +// TODO: initialize upper levels while respecting cache locality +// TODO: add counts from other KmerCounter, to enable reducing parallel counts + +package alphabet + +import ( + "fmt" +) + +type KmerCounter struct { + alphabet *Alphabet + num_symbols uint8 + max_k uint8 + total uint64 + children []Node +} + +type Node struct { + succ *Node + child *Node + encoding uint8 + count uint32 +} + +func NewKmerCounter(a *Alphabet, max_k uint8) *KmerCounter { + kc := new(KmerCounter) + kc.alphabet = a + kc.num_symbols = uint8(len(a.symbols)) + kc.max_k = max_k + + kc.children = make([]Node, kc.num_symbols) + for i := uint8(1); i < kc.num_symbols; i++ { + kc.children[i].encoding = i + kc.children[i-1].succ = &kc.children[i] + } + return kc +} + +func lookupChild(n *Node, encoding uint8) *Node { + for c := n.child; c != nil && encoding <= c.encoding; c = c.succ { + if encoding == c.encoding { return c } + } + return nil +} + +func insertChild(p *Node, encoding uint8, n *Node) { + if p.child == nil { + p.child = n + } else if (n.encoding < p.child.encoding) { + n.succ = p.child + p.child = n + } else { + for c := p.child; c.encoding < encoding; c = c.succ { + if encoding < c.succ.encoding { + n.succ = c.succ + c.succ = n + } + } + } + n.count++ // not sure why this is needed + n.encoding = encoding +} + +func Observe(kc *KmerCounter, seq string) error { + cbuf, err := kc.alphabet.EncodeAll(seq[:kc.max_k]) + if err != nil { return err } + + CreateChildren := func (index int, remaining int) *Node { + if remaining == 0 { return nil } // this condition should never happen + nodes := make([]Node, remaining) + + for j, n := range nodes { + n.count = 1 + if j != 0 { + nodes[j-1].child = &n + n.encoding = cbuf[(index+j) % int(kc.max_k)] + } + } + + return &nodes[0] + } + + UpdateBuffer := func (index int) (max_k int, err error) { + max_k = int(kc.max_k) + lookahead := index + max_k - 1 + var encoding uint8 + if lookahead < len(seq) { + next := string(seq[lookahead]) + encoding, err = kc.alphabet.Encode(next) + if err != nil { + err = fmt.Errorf("in position %d: %w", index, err) + return + } + cbuf[lookahead % int(kc.max_k)] = encoding + } else { + max_k = len(seq) - index + } + return + } + + var encoding uint8 + for i, _ := range seq { + max_k, err := UpdateBuffer(i) + if err != nil { return err } + + p := &kc.children[cbuf[i % int(kc.max_k)]] + kc.total++ + for k := 1; k <= max_k; k++ { + p.count++ + // fmt.Printf("%d %d %v\n", i, k, p) + + if k != max_k { + encoding = cbuf[(i+k) % int(kc.max_k)] + c := lookupChild(p, encoding) + if c == nil { + insertChild(p, encoding, CreateChildren(i+k, max_k-k)) + break // inserted nodes already have count = 1 added + } + p = c + } + } + } + return nil +} + +func LookupCount(kc *KmerCounter, kmer string) (count uint32, err error) { + if len(kmer) > int(kc.max_k) { + err = fmt.Errorf("kmer_counter: attempted to lookup count of %d-mer which exceeds max supported length %d", len(kmer), kc.max_k) + return + } + if len(kmer) == 0 { + err = fmt.Errorf("kmer_counter: attempted to lookup count of 0mer") + return + } + encoded, err := kc.alphabet.EncodeAll(kmer) + if err != nil { return } + + var k int + var p *Node + for k, p = 1, &kc.children[encoded[0]]; k < len(kmer) && p != nil; k++ { + p = lookupChild(p, encoded[k]) + } + if p == nil { return } + count = p.count + return +} + +func LookupFrequency(kc *KmerCounter, kmer string) (float64, error) { + count, err := LookupCount(kc, kmer) + if err != nil { return 0, err } + return float64(count) / float64(kc.total - uint64(len(kmer)-1)), nil +} diff --git a/alphabet/kmer_counter_test.go b/alphabet/kmer_counter_test.go new file mode 100644 index 00000000..f59e88bc --- /dev/null +++ b/alphabet/kmer_counter_test.go @@ -0,0 +1,67 @@ +package alphabet + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TimothyStiles/poly/transform/variants" +) + +func Assertf (t *testing.T, cond bool, format string, args ...any) { + if !cond { + t.Errorf(format, args...) + } +} + +func TestKmerCounterEmpty(t *testing.T) { + kc := NewKmerCounter(DNA, 3) + + count, err := LookupCount(kc, "A") + assert.EqualValues(t, 0, count) + assert.NoError(t, err) + count, err = LookupCount(kc, "X") + assert.Error(t, err) +} + +func TestKmerCounterRepeated(t *testing.T) { + kc := NewKmerCounter(DNA, 3) + seq := strings.Repeat("A", 12) + assert.Equal(t, 12, len(seq)) + err := Observe(kc, seq) + assert.NoError(t, err) + + assert.Equal(t, 12, int(kc.total)) + + onemers, _ := variants.AllVariantsIUPAC("N") + for _, kmer := range onemers { + count, err := LookupCount(kc, kmer) + assert.NoError(t, err) + if kmer == "A" { + assert.EqualValues(t, 12, count, "Wrong count") + } else { + assert.EqualValues(t, 0, count, "Reports nonzero count for 1mer not present in sequence %v", kmer) + } + } + twomers, _ := variants.AllVariantsIUPAC("NN") + for _, kmer := range twomers { + count, err := LookupCount(kc, kmer) + assert.NoError(t, err) + if kmer == "AA" { + assert.EqualValues(t, 11, count, "Wrong count") + } else { + assert.EqualValues(t, 0, count, "Reports nonzero count for 2mer not present in sequence %v", kmer) + } + } + threemers, _ := variants.AllVariantsIUPAC("NNN") + for _, kmer := range threemers { + count, err := LookupCount(kc, kmer) + assert.NoError(t, err) + if kmer == "AAA" { + assert.EqualValues(t, 10, count, "Wrong count") + } else { + assert.EqualValues(t, 0, count, "Reports nonzero count for 3mer not present in sequence %v", kmer) + } + } +}