diff --git a/.travis.yml b/.travis.yml index 0cf5b15..b768036 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,7 +7,7 @@ os: language: go go: - - 1.9 +- tip # The latest version of Go. install: - go get -d -t -v ./... diff --git a/cmd/align.go b/cmd/align.go index fe2346a..34e5403 100644 --- a/cmd/align.go +++ b/cmd/align.go @@ -32,7 +32,7 @@ import ( "github.com/pkg/profile" "github.com/spf13/cobra" "github.com/will-rowe/groot/src/graph" - "github.com/will-rowe/groot/src/lshIndex" + "github.com/will-rowe/groot/src/lshForest" "github.com/will-rowe/groot/src/misc" "github.com/will-rowe/groot/src/stream" "github.com/will-rowe/groot/src/version" @@ -160,8 +160,8 @@ func alignParamCheck() error { func runAlign() { // set up profiling if *profiling == true { - defer profile.Start(profile.MemProfile, profile.ProfilePath("./")).Stop() - //defer profile.Start(profile.ProfilePath("./")).Stop() + //defer profile.Start(profile.MemProfile, profile.ProfilePath("./")).Stop() + defer profile.Start(profile.ProfilePath("./")).Stop() } // start logging logFH := misc.StartLogging(*logFile) @@ -187,29 +187,21 @@ func runAlign() { log.Print("loading index information...") info := new(misc.IndexInfo) misc.ErrorCheck(info.Load(*indexDir + "/index.info")) - if info.Containment { - log.Printf("\tindex type: lshEnsemble") - log.Printf("\tcontainment search seeding: enabled") - } else { - log.Printf("\tindex type: lshForest") - log.Printf("\tcontainment search seeding: disabled") - } - log.Printf("\twindow sized used in indexing: %d\n", info.ReadLength) log.Printf("\tk-mer size: %d\n", info.Ksize) log.Printf("\tsignature size: %d\n", info.SigSize) log.Printf("\tJaccard similarity theshold: %0.2f\n", info.JSthresh) + log.Printf("\twindow sized used in indexing: %d\n", info.ReadLength) log.Print("loading the groot graphs...") graphStore := make(graph.GraphStore) misc.ErrorCheck(graphStore.Load(*indexDir + "/index.graph")) log.Printf("\tnumber of variation graphs: %d\n", len(graphStore)) log.Print("loading the MinHash signatures...") - var database *lshIndex.LshEnsemble - if info.Containment { - database = lshIndex.NewLSHensemble(make([]lshIndex.Partition, lshIndex.PARTITIONS), info.SigSize, lshIndex.MAXK) - } else { - database = lshIndex.NewLSHforest(info.SigSize, info.JSthresh) - } + database := lshForest.NewLSHforest(info.SigSize, info.JSthresh) misc.ErrorCheck(database.Load(*indexDir + "/index.sigs")) + database.Index() + numHF, numBucks := database.Settings() + log.Printf("\tnumber of hash functions per bucket: %d\n", numHF) + log.Printf("\tnumber of buckets: %d\n", numBucks) /////////////////////////////////////////////////////////////////////////////////////// // create SAM references from the sequences held in the graphs referenceMap, err := graphStore.GetRefs() @@ -230,12 +222,10 @@ func runAlign() { // add in the process parameters dataStream.InputFile = *fastq - fastqChecker.Containment = info.Containment fastqChecker.WindowSize = info.ReadLength dbQuerier.Db = database dbQuerier.CommandInfo = info dbQuerier.GraphStore = graphStore - dbQuerier.Threshold = info.JSthresh graphAligner.GraphStore = graphStore graphAligner.RefMap = referenceMap graphAligner.MaxClip = *clip diff --git a/cmd/index.go b/cmd/index.go index 5a7ebd9..5ac9fd8 100644 --- a/cmd/index.go +++ b/cmd/index.go @@ -35,7 +35,7 @@ import ( "github.com/spf13/cobra" "github.com/will-rowe/gfa" "github.com/will-rowe/groot/src/graph" - "github.com/will-rowe/groot/src/lshIndex" + "github.com/will-rowe/groot/src/lshForest" "github.com/will-rowe/groot/src/misc" "github.com/will-rowe/groot/src/seqio" "github.com/will-rowe/groot/src/version" @@ -51,9 +51,6 @@ var ( msaList []string // the collected MSA files outDir *string // directory to save index files and log to defaultOutDir = "./groot-index-" + string(time.Now().Format("20060102150405")) // a default dir to store the index files - containment *bool // use lshEnsemble instead of lshForest -- allows for variable read length - maxK *int // the maxK for LSH Ensemble (only active for --containment) - numPart *int // the number of partitions for LSH Ensemble (only active for --containment) ) // the index command (used by cobra) @@ -71,15 +68,12 @@ var indexCmd = &cobra.Command{ // a function to initialise the command line arguments func init() { - kSize = indexCmd.Flags().IntP("kmerSize", "k", 31, "size of k-mer") - sigSize = indexCmd.Flags().IntP("sigSize", "s", 84, "size of MinHash signature") - readLength = indexCmd.Flags().IntP("readLength", "l", 100, "max length of query reads (which will be aligned during the align subcommand)") - jsThresh = indexCmd.Flags().Float64P("jsThresh", "j", 0.99, "minimum Jaccard similarity for a seed to be recorded (note: this is used as a containment theshold when --containment set") + kSize = indexCmd.Flags().IntP("kmerSize", "k", 7, "size of k-mer") + sigSize = indexCmd.Flags().IntP("sigSize", "s", 128, "size of MinHash signature") + readLength = indexCmd.Flags().IntP("readLength", "l", 100, "length of query reads (which will be aligned during the align subcommand)") + jsThresh = indexCmd.Flags().Float64P("jsThresh", "j", 0.99, "minimum Jaccard similarity for a seed to be recorded") msaDir = indexCmd.Flags().StringP("msaDir", "i", "", "directory containing the clustered references (MSA files) - required") outDir = indexCmd.PersistentFlags().StringP("outDir", "o", defaultOutDir, "directory to save index files to") - containment = indexCmd.Flags().BoolP("containment", "c", false, "use lshEnsemble instead of lshForest (allows read length variation > 10 bases)") - maxK = indexCmd.Flags().IntP("maxK", "m", 4, "maxK in LSH Ensemble (only active with --containment)") - numPart = indexCmd.Flags().IntP("numPart", "n", 4, "num. partitions in LSH Ensemble (only active with --containment)") indexCmd.MarkFlagRequired("msaDir") RootCmd.AddCommand(indexCmd) } @@ -142,11 +136,6 @@ func runIndex() { // check the supplied files and then log some stuff log.Printf("checking parameters...") misc.ErrorCheck(indexParamCheck()) - if *containment { - log.Printf("\tindexing scheme: lshEnsemble (containment search)") - } else { - log.Printf("\tindexing scheme: lshForest") - } log.Printf("\tprocessors: %d", *proc) log.Printf("\tk-mer size: %d", *kSize) log.Printf("\tsignature size: %d", *sigSize) @@ -209,57 +198,56 @@ func runIndex() { }() /////////////////////////////////////////////////////////////////////////////////////// // collect and store the GrootGraph windows - sigStore := []*lshIndex.GraphWindow{} - lookupMap := make(lshIndex.KeyLookupMap) + var sigStore = make([]map[int]map[int][][]uint64, len(graphStore)) + for i := range sigStore { + sigStore[i] = make(map[int]map[int][][]uint64) + } // receive the signatures + var sigCount int = 0 for window := range windowChan { - // combine graphID, nodeID and offset to form a string key for signature - stringKey := fmt.Sprintf("g%dn%do%d", window.GraphID, window.Node, window.OffSet) - // convert to a graph window - gw := &lshIndex.GraphWindow{stringKey, *readLength, window.Sig} + // initialise the inner map of sigStore if graph has not been seen yet + if _, ok := sigStore[window.GraphID][window.Node]; !ok { + sigStore[window.GraphID][window.Node] = make(map[int][][]uint64) + } // store the signature for the graph:node:offset - sigStore = append(sigStore, gw) - // add a key to the lookup map - lookupMap[stringKey] = seqio.Key{GraphID: window.GraphID, Node: window.Node, OffSet: window.OffSet} + sigStore[window.GraphID][window.Node][window.OffSet] = append(sigStore[window.GraphID][window.Node][window.OffSet], window.Sig) + sigCount++ } - numSigs := len(sigStore) - log.Printf("\tnumber of signatures generated: %d\n", numSigs) - var database *lshIndex.LshEnsemble - if *containment == false { - /////////////////////////////////////////////////////////////////////////////////////// - // run LSH forest - log.Printf("running LSH Forest...\n") - database = lshIndex.NewLSHforest(*sigSize, *jsThresh) - // range over the nodes in each graph, each node will have one or more signature - for window := range lshIndex.Windows2Chan(sigStore) { - // add the signature to the lshForest - database.Add(window.Key, window.Signature, 0) + log.Printf("\tnumber of signatures generated: %d\n", sigCount) + /////////////////////////////////////////////////////////////////////////////////////// + // run LSH forest + log.Printf("running LSH forest...\n") + database := lshForest.NewLSHforest(*sigSize, *jsThresh) + // range over the nodes in each graph, each node will have one or more signature + for graphID, nodesMap := range sigStore { + // add each signature to the database + for nodeID, offsetMap := range nodesMap { + for offset, signatures := range offsetMap { + for _, signature := range signatures { + // combine graphID, nodeID and offset to form a string key for signature + stringKey := fmt.Sprintf("g%dn%do%d", graphID, nodeID, offset) + // add the key to a lookup map + key := seqio.Key{GraphID: graphID, Node: nodeID, OffSet: offset} + database.KeyLookup[stringKey] = key + // add the signature to the lshForest + database.Add(stringKey, signature) + } + } } - // print some stuff - numHF, numBucks := database.Lshes[0].Settings() - log.Printf("\tnumber of hash functions per bucket: %d\n", numHF) - log.Printf("\tnumber of buckets: %d\n", numBucks) - } else { - /////////////////////////////////////////////////////////////////////////////////////// - // run LSH ensemble (https://github.com/ekzhu/lshensemble) - log.Printf("running LSH Ensemble...\n") - database = lshIndex.BootstrapLshEnsemble(*numPart, *sigSize, *maxK, numSigs, lshIndex.Windows2Chan(sigStore)) - // print some stuff - log.Printf("\tnumber of LSH Ensemble partitions: %d\n", *numPart) - log.Printf("\tmax no. hash functions per bucket: %d\n", *maxK) } - // attach the key lookup map to the index - database.KeyLookup = lookupMap + numHF, numBucks := database.Settings() + log.Printf("\tnumber of hash functions per bucket: %d\n", numHF) + log.Printf("\tnumber of buckets: %d\n", numBucks) /////////////////////////////////////////////////////////////////////////////////////// // record runtime info - info := &misc.IndexInfo{Version: version.VERSION, Ksize: *kSize, SigSize: *sigSize, JSthresh: *jsThresh, ReadLength: *readLength, Containment: *containment} + info := &misc.IndexInfo{Version: version.VERSION, Ksize: *kSize, SigSize: *sigSize, JSthresh: *jsThresh, ReadLength: *readLength} // save the index files log.Printf("saving index files to \"%v\"...", *outDir) - misc.ErrorCheck(database.Dump(*outDir + "/index.sigs")) - log.Printf("\tsaved MinHash signatures") misc.ErrorCheck(info.Dump(*outDir + "/index.info")) log.Printf("\tsaved runtime info") misc.ErrorCheck(graphStore.Dump(*outDir + "/index.graph")) log.Printf("\tsaved groot graphs") + misc.ErrorCheck(database.Dump(*outDir + "/index.sigs")) + log.Printf("\tsaved MinHash signatures") log.Println("finished") } diff --git a/cmd/report.go b/cmd/report.go index 0039687..a41f1b9 100644 --- a/cmd/report.go +++ b/cmd/report.go @@ -98,7 +98,7 @@ func reportParamCheck() error { log.Printf("\tBAM file: %v", *bamFile) } if *covCutoff > 1.0 { - return fmt.Errorf("supplied coverage cutoff exceeds 1.0 (100%%): %.2f", *covCutoff) + return fmt.Errorf("supplied coverage cutoff exceeds 1.0 (100%%): %v", *covCutoff) } return nil } diff --git a/paper/benchmarking/accuracy/accuracy-test.go b/paper/benchmarking/accuracy/accuracy-test.go index 9b33434..b92aabd 100644 --- a/paper/benchmarking/accuracy/accuracy-test.go +++ b/paper/benchmarking/accuracy/accuracy-test.go @@ -3,13 +3,14 @@ package main import ( "flag" "fmt" - "github.com/biogo/hts/bam" - "github.com/biogo/hts/bgzf" - "github.com/biogo/hts/sam" "io" "log" "os" "strings" + + "github.com/biogo/hts/bam" + "github.com/biogo/hts/bgzf" + "github.com/biogo/hts/sam" ) var inputFile = flag.String("bam", "", "bam file to run accuracy test on") @@ -28,7 +29,7 @@ func main() { log.Fatalf("could not open bam file %q:", err) } if !ok { - log.Printf("file %q has no bgzf magic block: may be truncated", inputFile) + log.Printf("file %v has no bgzf magic block: may be truncated", inputFile) } r = f b, err := bam.NewReader(r, 0) diff --git a/src/alignment/alignment_test.go b/src/alignment/alignment_test.go index 2c544fd..ecb965f 100644 --- a/src/alignment/alignment_test.go +++ b/src/alignment/alignment_test.go @@ -1,8 +1,8 @@ package alignment import ( - "fmt" "io" + "log" "os" "testing" @@ -18,12 +18,12 @@ var ( sigSize = 128 ) -func loadGFA() (*gfa.GFA, error) { +func loadGFA() *gfa.GFA { // load the GFA file fh, err := os.Open(inputFile) reader, err := gfa.NewReader(fh) if err != nil { - return nil, fmt.Errorf("can't read gfa file: %v", err) + log.Fatalf("can't read gfa file: %v", err) } // collect the GFA instance myGFA := reader.CollectGFA() @@ -34,13 +34,13 @@ func loadGFA() (*gfa.GFA, error) { break } if err != nil { - return nil, fmt.Errorf("error reading line in gfa file: %v", err) + log.Fatalf("error reading line in gfa file: %v", err) } if err := line.Add(myGFA); err != nil { - return nil, fmt.Errorf("error adding line to GFA instance: %v", err) + log.Fatalf("error adding line to GFA instance: %v", err) } } - return myGFA, nil + return myGFA } func setupMultimapRead() (*seqio.FASTQread, error) { @@ -80,16 +80,13 @@ func TestExactMatchMultiMapper(t *testing.T) { // create the read testRead, err := setupMultimapRead() if err != nil { - t.Fatal(err) + log.Fatal(err) } // create the GrootGraph and graphStore - myGFA, err := loadGFA() - if err != nil { - t.Fatal(err) - } + myGFA := loadGFA() grootGraph, err := graph.CreateGrootGraph(myGFA, 1) if err != nil { - t.Fatal(err) + log.Fatal(err) } graphStore := make(graph.GraphStore) graphStore[grootGraph.GraphID] = grootGraph @@ -116,16 +113,13 @@ func TestExactMatchUniqMapper(t *testing.T) { // create the read testRead, err := setupUniqmapRead() if err != nil { - t.Fatal(err) + log.Fatal(err) } // create the GrootGraph and graphStore - myGFA, err := loadGFA() - if err != nil { - t.Fatal(err) - } + myGFA := loadGFA() grootGraph, err := graph.CreateGrootGraph(myGFA, 1) if err != nil { - t.Fatal(err) + log.Fatal(err) } graphStore := make(graph.GraphStore) graphStore[grootGraph.GraphID] = grootGraph diff --git a/src/graph/graph.go b/src/graph/graph.go index 26a3cd0..724167c 100644 --- a/src/graph/graph.go +++ b/src/graph/graph.go @@ -4,8 +4,8 @@ package graph import ( "bytes" "encoding/binary" - "encoding/gob" "fmt" + "io/ioutil" "os" "sort" "strconv" @@ -14,8 +14,9 @@ import ( "github.com/biogo/hts/sam" "github.com/will-rowe/gfa" - "github.com/will-rowe/groot/src/seqio" "github.com/will-rowe/groot/src/misc" + "github.com/will-rowe/groot/src/seqio" + "gopkg.in/vmihailenco/msgpack.v2" ) /* @@ -430,24 +431,20 @@ type GraphStore map[int]*GrootGraph // Dump is a method to save a GrootGraph to file func (graphStore *GraphStore) Dump(path string) error { - file, err := os.Create(path) - if err == nil { - encoder := gob.NewEncoder(file) - encoder.Encode(graphStore) + b, err := msgpack.Marshal(graphStore) + if err != nil { + return err } - file.Close() - return err + return ioutil.WriteFile(path, b, 0644) } // Load is a method to load a GrootGraph from file func (graphStore *GraphStore) Load(path string) error { - file, err := os.Open(path) - if err == nil { - decoder := gob.NewDecoder(file) - err = decoder.Decode(graphStore) + b, err := ioutil.ReadFile(path) + if err != nil { + return err } - file.Close() - return err + return msgpack.Unmarshal(b, graphStore) } // GetRefs is a method to convert all paths held in graphStore to sam.References diff --git a/src/graph/graph_test.go b/src/graph/graph_test.go index 4eea74a..5d4a270 100644 --- a/src/graph/graph_test.go +++ b/src/graph/graph_test.go @@ -1,11 +1,12 @@ package graph import ( - "fmt" - "github.com/will-rowe/gfa" "io" + "log" "os" "testing" + + "github.com/will-rowe/gfa" ) var ( @@ -17,12 +18,12 @@ var ( blaB10 = []byte("ATGAAAGGATTAAAAGGGCTATTGGTTCTGGCTTTAGGCTTTACAGGACTACAGGTTTTTGGGCAACAGAACCCTGATATTAAAATTGAAAAATTAAAAGATAATTTATACGTCTATACAACCTATAATACCTTCAAAGGAACTAAATATGCGGCTAATGCGGTATATATGGTAACCGATAAAGGAGTAGTGGTTATAGACTCTCCATGGGGAGAAGATAAATTTAAAAGTTTTACAGACGAGATTTATAAAAAGCACGGAAAGAAAGTTATCATGAACATTGCAACCCACTCTCATGATGATAGAGCCGGAGGTCTTGAATATTTTGGTAAACTAGGTGCAAAAACTTATTCTACTAAAATGACAGATTCTATTTTAGCAAAAGAGAATAAGCCAAGAGCAAAGTACACTTTTGATAATAATAAATCTTTTAAAGTAGGAAAGACTGAGTTTCAGGTTTATTATCCGGGAAAAGGTCATACAGCAGATAATGTGGTTGTGTGGTTTCCTAAAGACAAAGTATTAGTAGGAGGCTGCATTGTAAAAAGTGGTGATTCGAAAGACCTTGGGTTTATTGGGGAAGCTTATGTAAACGACTGGACACAGTCCATACACAACATTCAGCAGAAATTTCCCTATGTTCAGTATGTCGTTGCAGGTCATGACGACTGGAAAGATCAAACATCAATACAACATACACTGGATTTAATCAGTGAATATCAACAAAAACAAAAGGCTTCAAATTAA") ) -func loadGFA() (*gfa.GFA, error) { +func loadGFA() *gfa.GFA { // load the GFA file fh, err := os.Open(inputFile) reader, err := gfa.NewReader(fh) if err != nil { - return nil, fmt.Errorf("can't read gfa file: %v", err) + log.Fatalf("can't read gfa file: %v", err) } // collect the GFA instance myGFA := reader.CollectGFA() @@ -33,30 +34,30 @@ func loadGFA() (*gfa.GFA, error) { break } if err != nil { - return nil, fmt.Errorf("error reading line in gfa file: %v", err) + log.Fatalf("error reading line in gfa file: %v", err) } if err := line.Add(myGFA); err != nil { - return nil, fmt.Errorf("error adding line to GFA instance: %v", err) + log.Fatalf("error adding line to GFA instance: %v", err) } } - return myGFA, nil + return myGFA } -func loadMSA() (*gfa.GFA, error) { +func loadMSA() *gfa.GFA { // load the MSA msa, _ := gfa.ReadMSA(inputFile2) // convert the MSA to a GFA instance myGFA, err := gfa.MSA2GFA(msa) - return myGFA, err + if err != nil { + log.Fatal(err) + } + return myGFA } // test CreateGrootGraph func TestCreateGrootGraph(t *testing.T) { - myGFA, err := loadGFA() - if err != nil { - t.Fatal(err) - } - _, err = CreateGrootGraph(myGFA, 1) + myGFA := loadGFA() + _, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) } @@ -64,10 +65,7 @@ func TestCreateGrootGraph(t *testing.T) { // test Graph2Seq func TestGraph2Seq(t *testing.T) { - myGFA, err := loadGFA() - if err != nil { - t.Fatal(err) - } + myGFA := loadGFA() grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) @@ -85,10 +83,8 @@ func TestGraph2Seq(t *testing.T) { // test WindowGraph func TestWindowGraph(t *testing.T) { - myGFA, err := loadMSA() - if err != nil { - t.Fatal(err) - } + myGFA := loadMSA() + //myGFA := loadGFA() grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) @@ -104,10 +100,7 @@ func TestWindowGraph(t *testing.T) { // test GraphStore dump/load func TestGraphStore(t *testing.T) { - myGFA, err := loadGFA() - if err != nil { - t.Fatal(err) - } + myGFA := loadGFA() grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) @@ -131,10 +124,7 @@ func TestGraphStore(t *testing.T) { // test DumpGraph to save a gfa func TestGraphDump(t *testing.T) { - myGFA, err := loadGFA() - if err != nil { - t.Fatal(err) - } + myGFA := loadGFA() grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) diff --git a/src/lshForest/lshForest.go b/src/lshForest/lshForest.go new file mode 100644 index 0000000..267a736 --- /dev/null +++ b/src/lshForest/lshForest.go @@ -0,0 +1,298 @@ +package lshForest + +import ( + "encoding/binary" + "fmt" + "io/ioutil" + "math" + "sort" + "sync" + + "github.com/will-rowe/groot/src/seqio" + "gopkg.in/vmihailenco/msgpack.v2" +) + +// set to 2/4/8 for 16bit/32bit/64bit hash values +const HASH_SIZE = 8 + +/* + The LSH forest +*/ +type LSHforest struct { + K int + L int + InitialHashTables []initialHashTable + hashTables []hashTables + hashedSignatureFunc hashedSignatureFunc + KeyLookup KeyLookupMap +} + +/* + A function to construct the LSH forest +*/ +func NewLSHforest(sigSize int, jsThresh float64) *LSHforest { + // calculate the optimal number of bands and hash functions to use based on the length of MinHash signature and a Jaccard Similarity theshhold + K, L, _, _ := optimise(sigSize, jsThresh) + // create the initial hash tables + InitialHashTables := make([]initialHashTable, L) + for i := range InitialHashTables { + InitialHashTables[i] = make(initialHashTable) + } + // create the hash tables that will be populated once the LSH forest indexing method has been run + indexedHashTables := make([]hashTables, L) + for i := range indexedHashTables { + indexedHashTables[i] = make(hashTables, 0) + } + // create the KeyLookup map to relate signatures to graph locations + KeyLookup := make(KeyLookupMap) + // return the address of the new LSH forest + newLSHforest := new(LSHforest) + newLSHforest.K = K + newLSHforest.L = L + newLSHforest.InitialHashTables = InitialHashTables + newLSHforest.hashTables = indexedHashTables + newLSHforest.hashedSignatureFunc = hashedSignatureFuncGen(HASH_SIZE) + newLSHforest.KeyLookup = KeyLookup + return newLSHforest +} + +/* + The types needed by the LSH forest +*/ +// this map relates the stringified seqio.Key to the original, allowing LSHforest search results to easily be related to graph locations +type KeyLookupMap map[string]seqio.Key + +// graphKeys is a slice containing all the stringified graphKeys for a given hashed signature +type graphKeys []string + +// the initial hash table uses the hashed signature as a key - the values are the corresponding graphKeys +type initialHashTable map[string]graphKeys + +// a band is a single hash table that is stored in the indexedHashTables - it contains the band of a hash signature and the corresponding graphKeys +type band struct { + HashedSignature string + graphKeys graphKeys +} + +// this is populated during indexing -- it is a slice of bands and can be sorted +type hashTables []band + +//methods to satisfy the sort interface +func (h hashTables) Len() int { return len(h) } +func (h hashTables) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h hashTables) Less(i, j int) bool { return h[i].HashedSignature < h[j].HashedSignature } + +// the hashkey function type and the generator function +type hashedSignatureFunc func([]uint64) string + +func hashedSignatureFuncGen(hashValueSize int) hashedSignatureFunc { + return func(sig []uint64) string { + hashedSig := make([]byte, hashValueSize*len(sig)) + buf := make([]byte, 8) + for i, v := range sig { + // use the ByteOrder interface to write binary data + // use the LittleEndian implementation and call the Put method + binary.LittleEndian.PutUint64(buf, v) + copy(hashedSig[i*hashValueSize:(i+1)*hashValueSize], buf[:hashValueSize]) + } + return string(hashedSig) + } +} + +/* + A method to return the number of hash functions and number of bands set by the LSH forest +*/ +func (self *LSHforest) Settings() (K, L int) { + return self.K, self.L +} + +/* + A method to add a minhash signature and graph key to the LSH forest +*/ +func (self *LSHforest) Add(key string, sig []uint64) { + // split the signature into the right number of bands and then hash each one + hashedSignature := make([]string, self.L) + for i := 0; i < self.L; i++ { + hashedSignature[i] = self.hashedSignatureFunc(sig[i*self.K : (i+1)*self.K]) + } + // iterate over each band in the LSH forest + for i := 0; i < len(self.InitialHashTables); i++ { + // if the current band in the signature isn't in the current band in the LSH forest, add it + if _, ok := self.InitialHashTables[i][hashedSignature[i]]; !ok { + self.InitialHashTables[i][hashedSignature[i]] = make(graphKeys, 1) + self.InitialHashTables[i][hashedSignature[i]][0] = key + // if it is, append the current key (graph location) to this hashed signature band + } else { + self.InitialHashTables[i][hashedSignature[i]] = append(self.InitialHashTables[i][hashedSignature[i]], key) + } + } +} + +/* + A method to index the graph (transfers contents of each initialHashTable so they can be sorted and searched) +*/ +func (self *LSHforest) Index() { + // iterate over the empty indexed hash tables + for i := range self.hashTables { + // transfer contents from the corresponding band in the initial hash table + for HashedSignature, keys := range self.InitialHashTables[i] { + self.hashTables[i] = append(self.hashTables[i], band{HashedSignature, keys}) + } + // sort the new hashtable and store it in the corresponding slot in the indexed hash tables + sort.Sort(self.hashTables[i]) + // clear the initial hashtable that has just been processed + self.InitialHashTables[i] = make(initialHashTable) + } +} + +/* + Methods to dump the LSH forest to disk and then load it again +*/ +// Dump an LSH index to disk +func (self *LSHforest) Dump(path string) error { + if len(self.hashTables[0]) != 0 { + return fmt.Errorf("cannot dump the LSH Forest after running the indexing method") + } + b, err := msgpack.Marshal(self) + if err != nil { + return err + } + return ioutil.WriteFile(path, b, 0644) +} + +// Load an LSH index from disk +func (self *LSHforest) Load(path string) error { + b, err := ioutil.ReadFile(path) + if err != nil { + return err + } + return msgpack.Unmarshal(b, self) +} + +/* + A method to query a MinHash signature against the LSH forest +*/ +func (self *LSHforest) Query(sig []uint64) []string { + result := make([]string, 0) + // more info on done chans for explicit cancellation in concurrent pipelines: https://blog.golang.org/pipelines + done := make(chan struct{}) + defer close(done) + // collect query results and aggregate in a single array to send back + for key := range self.runQuery(sig, done) { + result = append(result, key) + } + return result +} + +func (self *LSHforest) runQuery(sig []uint64, done <-chan struct{}) <-chan string { + queryResultChan := make(chan string) + go func() { + defer close(queryResultChan) + // hash the query signature + hashedSignature := make([]string, self.L) + for i := 0; i < self.L; i++ { + hashedSignature[i] = self.hashedSignatureFunc(sig[i*self.K : (i+1)*self.K]) + } + // don't send back multiple copies of the same key + seens := make(map[string]bool) + // compress internal nodes using a prefix + prefixSize := HASH_SIZE * self.K + // run concurrent hashtable queries + keyChan := make(chan string) + var wg sync.WaitGroup + wg.Add(self.L) + for i := 0; i < self.L; i++ { + go func(band hashTables, queryChunk string) { + defer wg.Done() + // sort.Search uses binary search to find and return the smallest index i in [0, n) at which f(i) is true + index := sort.Search(len(band), func(x int) bool { return band[x].HashedSignature[:prefixSize] >= queryChunk }) + // k is the index returned by the search + if index < len(band) && band[index].HashedSignature[:prefixSize] == queryChunk { + for j := index; j < len(band) && band[j].HashedSignature[:prefixSize] == queryChunk; j++ { + // copies key values from this hashtable to the keyChan until all values from band[j] copied or done is closed + for _, key := range band[j].graphKeys { + select { + case keyChan <- key: + case <-done: + return + } + } + } + } + }(self.hashTables[i], hashedSignature[i]) + } + go func() { + wg.Wait() + close(keyChan) + }() + for key := range keyChan { + if _, seen := seens[key]; seen { + continue + } + queryResultChan <- key + seens[key] = true + } + }() + return queryResultChan +} + +// the following funcs are taken from https://github.com/ekzhu/minhash-lsh + +// optimise returns the optimal number of hash functions and the optimal number of bands for Jaccard similarity search, as well as the false positive and negative probabilities. +func optimise(sigSize int, jsThresh float64) (int, int, float64, float64) { + optimumK, optimumL := 0, 0 + fp, fn := 0.0, 0.0 + minError := math.MaxFloat64 + for l := 1; l <= sigSize; l++ { + for k := 1; k <= sigSize; k++ { + if l*k > sigSize { + break + } + currFp := probFalsePositive(l, k, jsThresh, 0.01) + currFn := probFalseNegative(l, k, jsThresh, 0.01) + currErr := currFn + currFp + if minError > currErr { + minError = currErr + optimumK = k + optimumL = l + fp = currFp + fn = currFn + } + } + } + return optimumK, optimumL, fp, fn +} + +// Compute the integral of function f, lower limit a, upper limit l, and +// precision defined as the quantize step +func integral(f func(float64) float64, a, b, precision float64) float64 { + var area float64 + for x := a; x < b; x += precision { + area += f(x+0.5*precision) * precision + } + return area +} + +// Probability density function for false positive +func falsePositive(l, k int) func(float64) float64 { + return func(j float64) float64 { + return 1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l)) + } +} + +// Probability density function for false negative +func falseNegative(l, k int) func(float64) float64 { + return func(j float64) float64 { + return 1.0 - (1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l))) + } +} + +// Compute the cummulative probability of false negative given threshold t +func probFalseNegative(l, k int, t, precision float64) float64 { + return integral(falseNegative(l, k), t, 1.0, precision) +} + +// Compute the cummulative probability of false positive given threshold t +func probFalsePositive(l, k int, t, precision float64) float64 { + return integral(falsePositive(l, k), 0, t, precision) +} diff --git a/src/lshIndex/README.md b/src/lshIndex/README.md deleted file mode 100644 index 88757b4..0000000 --- a/src/lshIndex/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# lshIndex package - -Since version 0.8.0, GROOT has two options for indexing the variation graphs - lshForest or lshEnsemble. With the addition of lshEnsemble, GROOT can now receive variable read lengths and seed these against graphs using containment search. - -This has required a significant re-write of the lshForest code, which is all contained in this directory. The majority of this code now comes from the [lshEnsemble package](https://godoc.org/github.com/ekzhu/lshensemble) by ekzhu. I have just made a few changes: - -* removed methods that were unnecessary for GROOT -* added a method to write the index to disk -* added methods to generate a single LSH Forest index using a Jaccard Similarity threshold and signature length parameter diff --git a/src/lshIndex/lshEnsemble.go b/src/lshIndex/lshEnsemble.go deleted file mode 100644 index 01bb0df..0000000 --- a/src/lshIndex/lshEnsemble.go +++ /dev/null @@ -1,159 +0,0 @@ -package lshIndex - -import ( - "fmt" - "io/ioutil" - "sync" - - "gopkg.in/vmihailenco/msgpack.v2" - "github.com/orcaman/concurrent-map" - "github.com/will-rowe/groot/src/seqio" -) - -type param struct { - k int - l int -} - -// Partition represents a domain size partition in the LSH Ensemble index. -type Partition struct { - Lower int `json:"lower"` - Upper int `json:"upper"` -} - -// KeyLookupMap relates the stringified seqio.Key to the original, allowing LSH index search results to easily be related to graph locations -type KeyLookupMap map[string]seqio.Key - -// GraphWindow represents a region of a variation graph -type GraphWindow struct { - // The unique key of this window - Key interface{} - // The window size - Size int - // The MinHash signature of this window - Signature []uint64 -} - -// LshEnsemble represents an LSH Ensemble index. -type LshEnsemble struct { - Partitions []Partition - Lshes []*LshForest - MaxK int - NumHash int - Indexed bool - SingleForest bool - KeyLookup KeyLookupMap - paramCache cmap.ConcurrentMap -} - -// Add a new domain to the index given its partition ID - the index of the partition. -// The added domain won't be searchable until the Index() function is called. -func (e *LshEnsemble) Add(key interface{}, sig []uint64, partInd int) { - e.Lshes[partInd].Add(key, sig) -} - -// Index makes all added domains searchable. -func (e *LshEnsemble) Index() { - for i := range e.Lshes { - e.Lshes[i].Index() - } - e.Indexed = true -} - -// Query returns the candidate domain keys in a channel. -// This function is given the MinHash signature of the query domain, sig, the domain size, -// the containment threshold, and a cancellation channel. -// Closing channel done will cancel the query execution. -// The query signature must be generated using the same seed as the signatures of the indexed domains, -// and have the same number of hash functions. -func (e *LshEnsemble) Query(sig []uint64, size int, threshold float64, done <-chan struct{}) <-chan interface{} { - if e.SingleForest { - return e.queryForest(sig, done) - } - params := e.computeParams(size, threshold) - return e.queryWithParam(sig, params, done) -} - -// -func (e *LshEnsemble) queryWithParam(sig []uint64, params []param, done <-chan struct{}) <-chan interface{} { - // Collect candidates from all partitions - keyChan := make(chan interface{}) - var wg sync.WaitGroup - wg.Add(len(e.Lshes)) - for i := range e.Lshes { - go func(lsh *LshForest, k, l int) { - lsh.Query(sig, k, l, keyChan, done) - wg.Done() - }(e.Lshes[i], params[i].k, params[i].l) - } - go func() { - wg.Wait() - close(keyChan) - }() - return keyChan -} - -// -func (e *LshEnsemble) queryForest(sig []uint64, done <-chan struct{}) <-chan interface{} { - keyChan := make(chan interface{}) - var wg sync.WaitGroup - wg.Add(1) - go func(lsh *LshForest) { - lsh.Query(sig, -1, -1, keyChan, done) - wg.Done() - }(e.Lshes[0]) - go func() { - wg.Wait() - close(keyChan) - }() - return keyChan -} - -// Compute the optimal k and l for each partition -func (e *LshEnsemble) computeParams(size int, threshold float64) []param { - params := make([]param, len(e.Partitions)) - for i, p := range e.Partitions { - x := p.Upper - key := cacheKey(x, size, threshold) - if cached, exist := e.paramCache.Get(key); exist { - params[i] = cached.(param) - } else { - optK, optL, _, _ := e.Lshes[i].OptimalKL(x, size, threshold) - computed := param{optK, optL} - e.paramCache.Set(key, computed) - params[i] = computed - } - } - return params -} - -// Make a cache key with threshold precision to 2 decimal points -func cacheKey(x, q int, t float64) string { - return fmt.Sprintf("%.8x %.8x %.2f", x, q, t) -} - -// Dump an LSH index to disk -func (LshEnsemble *LshEnsemble) Dump(path string) error { - if LshEnsemble.Indexed == true { - return fmt.Errorf("cannot dump the LSH Index after running the indexing method") - } - b, err := msgpack.Marshal(LshEnsemble) - if err != nil { - return err - } - return ioutil.WriteFile(path, b, 0644) -} - -// Load an LSH index from disk -func (LshEnsemble *LshEnsemble) Load(path string) error { - b, err := ioutil.ReadFile(path) - if err != nil { - return err - } - err = msgpack.Unmarshal(b, LshEnsemble) - if err != nil { - return err - } - LshEnsemble.Index() - return nil -} diff --git a/src/lshIndex/lshForest.go b/src/lshIndex/lshForest.go deleted file mode 100644 index c2ae107..0000000 --- a/src/lshIndex/lshForest.go +++ /dev/null @@ -1,206 +0,0 @@ -package lshIndex - -import ( - "encoding/binary" - "math" - "sort" -) - -// -type keys []interface{} - -// For initial bootstrapping -type initHashTable map[string]keys - -// -type bucket struct { - hashKey string - keys keys -} - -// -type hashTable []bucket -func (h hashTable) Len() int { return len(h) } -func (h hashTable) Swap(i, j int) { h[i], h[j] = h[j], h[i] } -func (h hashTable) Less(i, j int) bool { return h[i].hashKey < h[j].hashKey } - -// -type hashKeyFunc func([]uint64) string - -// -func hashKeyFuncGen(hashValueSize int) hashKeyFunc { - return func(sig []uint64) string { - s := make([]byte, hashValueSize*len(sig)) - buf := make([]byte, 8) - for i, v := range sig { - // use the ByteOrder interface to write binary data - // use the LittleEndian implementation and call the Put method - binary.LittleEndian.PutUint64(buf, v) - copy(s[i*hashValueSize:(i+1)*hashValueSize], buf[:hashValueSize]) - } - return string(s) - } -} - -// LshForest represents a MinHash LSH implemented using LSH Forest -// (http://ilpubs.stanford.edu:8090/678/1/2005-14.pdf). -// It supports query-time setting of the MinHash LSH parameters -// L (number of bands) and -// K (number of hash functions per band). -type LshForest struct { - K int - L int - InitHashTables []initHashTable - HashTables []hashTable - hashKeyFunc hashKeyFunc -} - -// -func newLshForest(k, l int) *LshForest { - if k < 0 || l < 0 { - panic("k and l must be positive") - } - hashTables := make([]hashTable, l) - initHashTables := make([]initHashTable, l) - for i := range initHashTables { - initHashTables[i] = make(initHashTable) - } - return &LshForest{ - K: k, - L: l, - InitHashTables: initHashTables, - HashTables: hashTables, - hashKeyFunc: hashKeyFuncGen(HASH_SIZE), - } -} - -// Returns the number of hash functions per band and the number of bands -func (f *LshForest) Settings() (int, int) { - return f.K, f.L -} - -// Add a key with MinHash signature into the index. -// The key won't be searchable until Index() is called. -func (f *LshForest) Add(key interface{}, sig []uint64) { - // Generate hash keys - Hs := make([]string, f.L) - for i := 0; i < f.L; i++ { - Hs[i] = f.hashKeyFunc(sig[i*f.K : (i+1)*f.K]) - } - // Insert keys into the bootstrapping tables - for i := range f.InitHashTables { - ht := f.InitHashTables[i] - hk := Hs[i] - if _, exist := ht[hk]; exist { - ht[hk] = append(ht[hk], key) - } else { - ht[hk] = make(keys, 1) - ht[hk][0] = key - } - } -} - -// Index makes all the keys added searchable. -func (f *LshForest) Index() { - for i := range f.HashTables { - ht := make(hashTable, 0, len(f.InitHashTables[i])) - // Build sorted hash table using buckets from init hash tables - for hashKey, keys := range f.InitHashTables[i] { - ht = append(ht, bucket{ - hashKey: hashKey, - keys: keys, - }) - } - sort.Sort(ht) - f.HashTables[i] = ht - // Reset the init hash tables - f.InitHashTables[i] = make(initHashTable) - } -} - -// Query returns candidate keys given the query signature and parameters. -func (f *LshForest) Query(sig []uint64, K, L int, out chan<- interface{}, done <-chan struct{}) { - if K == -1 { - K = f.K - } - if L == -1 { - L = f.L - } - prefixSize := HASH_SIZE * K - // Generate hash keys - Hs := make([]string, L) - for i := 0; i < L; i++ { - Hs[i] = f.hashKeyFunc(sig[i*f.K : i*f.K+K]) - } - seens := make(map[interface{}]bool) - for i := 0; i < L; i++ { - ht := f.HashTables[i] - hk := Hs[i] - k := sort.Search(len(ht), func(x int) bool { - return ht[x].hashKey[:prefixSize] >= hk - }) - if k < len(ht) && ht[k].hashKey[:prefixSize] == hk { - for j := k; j < len(ht) && ht[j].hashKey[:prefixSize] == hk; j++ { - for _, key := range ht[j].keys { - if _, seen := seens[key]; seen { - continue - } - seens[key] = true - select { - case out <- key: - case <-done: - return - } - } - } - } - } -} - -// OptimalKL returns the optimal K and L for containment search, -// and the false positive and negative probabilities. -// where x is the indexed domain size, q is the query domain size, -// and t is the containment threshold. -func (f *LshForest) OptimalKL(x, q int, t float64) (optK, optL int, fp, fn float64) { - minError := math.MaxFloat64 - for l := 1; l <= f.L; l++ { - for k := 1; k <= f.K; k++ { - currFp := probFalsePositiveC(x, q, l, k, t, PRECISION) - currFn := probFalseNegativeC(x, q, l, k, t, PRECISION) - currErr := currFn + currFp - if minError > currErr { - minError = currErr - optK = k - optL = l - fp = currFp - fn = currFn - } - } - } - return -} - -// optimise returns the optimal number of hash functions and the optimal number of bands for Jaccard similarity search, as well as the false positive and negative probabilities. -func optimise(sigSize int, jsThresh float64) (int, int, float64, float64) { - optimumNumHashFuncs, optimumNumBands := 0, 0 - fp, fn := 0.0, 0.0 - minError := math.MaxFloat64 - for l := 1; l <= sigSize; l++ { - for k := 1; k <= sigSize; k++ { - if l*k > sigSize { - break - } - currFp := probFalsePositive(l, k, jsThresh, PRECISION) - currFn := probFalseNegative(l, k, jsThresh, PRECISION) - currErr := currFn + currFp - if minError > currErr { - minError = currErr - optimumNumHashFuncs = k - optimumNumBands = l - fp = currFp - fn = currFn - } - } - } - return optimumNumHashFuncs, optimumNumBands, fp, fn -} diff --git a/src/lshIndex/lshIndex.go b/src/lshIndex/lshIndex.go deleted file mode 100644 index 3a14b06..0000000 --- a/src/lshIndex/lshIndex.go +++ /dev/null @@ -1,96 +0,0 @@ -package lshIndex - -import ( - //"errors" - "github.com/orcaman/concurrent-map" -) - -// set to 2/4/8 for 16bit/32bit/64bit hash values -const HASH_SIZE = 8 -// integration precision for optimising number of bands + hash functions in LSH Forest -const PRECISION = 0.01 -// number of partitions and maxK to use in LSH Ensemble (TODO: add these as customisable parameters for GROOT) -const PARTITIONS = 6 -const MAXK = 4 - -// error messages -//var ( - //querySizeError = errors.New("Query size is > +/- 10 bases of reference windows, re-index using --containment") -//) - -// NewLSHensemble initializes a new index consisting of MinHash LSH implemented using LshForest. -// numHash is the number of hash functions in MinHash. -// maxK is the maximum value for the MinHash parameter K - the number of hash functions per "band". -func NewLSHensemble(parts []Partition, numHash, maxK int) *LshEnsemble { - lshes := make([]*LshForest, len(parts)) - for i := range lshes { - lshes[i] = newLshForest(maxK, numHash/maxK) - } - return &LshEnsemble{ - Lshes: lshes, - Partitions: parts, - MaxK: maxK, - NumHash: numHash, - paramCache: cmap.New(), - } -} - -// NewLshForest initializes a new index consisting of MinHash LSH implemented using a single LshForest. -// sigSize is the number of hash functions in MinHash. -// jsThresh is the minimum Jaccard similarity needed for a query to return a match -func NewLSHforest(sigSize int, jsThresh float64) *LshEnsemble { - // calculate the optimal number of bands and hash functions to use - numHashFuncs, numBands, _, _ := optimise(sigSize, jsThresh) - lshes := make([]*LshForest, 1) - lshes[0] = newLshForest(numHashFuncs, numBands) - return &LshEnsemble{ - Lshes: lshes, - Partitions: make([]Partition, 1), - MaxK: numBands, - NumHash: numHashFuncs, - paramCache: cmap.New(), - SingleForest: true, - } -} - -// BoostrapLshEnsemble builds an index from a channel of domains. -// The returned index consists of MinHash LSH implemented using LshForest. -// numPart is the number of partitions to create. -// numHash is the number of hash functions in MinHash. -// maxK is the maximum value for the MinHash parameter K - the number of hash functions per "band". -// GraphWindow is a channel emitting windows (don't need to be sorted by their sizes as windows are constant) TODO: should probably add a check for this -func BootstrapLshEnsemble(numPart, numHash, maxK, totalNumWindows int, windows <-chan *GraphWindow) *LshEnsemble { - index := NewLSHensemble(make([]Partition, numPart), numHash, maxK) - bootstrap(index, totalNumWindows, windows) - return index -} - -// bootstrap -func bootstrap(index *LshEnsemble, totalNumWindows int, windows <-chan *GraphWindow) { - numPart := len(index.Partitions) - depth := totalNumWindows / numPart - var currDepth, currPart int - for rec := range windows { - index.Add(rec.Key, rec.Signature, currPart) - currDepth++ - index.Partitions[currPart].Upper = rec.Size - if currDepth >= depth && currPart < numPart-1 { - currPart++ - index.Partitions[currPart].Lower = rec.Size - currDepth = 0 - } - } - return -} - -// Windows2Chan is a utility function that converts a GraphWindow slice in memory to a GraphWindow channel. -func Windows2Chan(windows []*GraphWindow) <-chan *GraphWindow { - c := make(chan *GraphWindow, 1000) - go func() { - for _, w := range windows { - c <- w - } - close(c) - }() - return c -} \ No newline at end of file diff --git a/src/lshIndex/lshIndex_test.go b/src/lshIndex/lshIndex_test.go deleted file mode 100644 index 6d4e17d..0000000 --- a/src/lshIndex/lshIndex_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// testing is incomplete, more to be added... -package lshIndex - -import ( - "fmt" - "os" - "testing" -) - -var ( - // test graph windows - entry1 = &GraphWindow{ - Key : fmt.Sprintf("g%dn%do%d", 1, 2, 3), - Size : 100, - Signature : []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - } - entry2 = &GraphWindow{ - Key : fmt.Sprintf("g%dn%do%d", 1, 3, 1), - Size : 100, - Signature : []uint64{1, 4, 3, 4, 5, 5, 7, 4, 9, 10}, - } - entry3 = &GraphWindow{ - Key : fmt.Sprintf("g%dn%do%d", 3, 22, 2), - Size : 100, - Signature : []uint64{4, 4, 3, 4, 5, 6, 7, 4, 9, 4}, - } - entries = []*GraphWindow{entry1, entry2, entry3} - // LSH Forest parameters - jsThresh = 0.85 - // LSH Ensemble parameters - numPart = 4 - numHash = 10 - maxK = 4 - // query for LSH Forest - query1 = &GraphWindow{ - Key : fmt.Sprintf("g%dn%do%d", 1, 2, 3), - Size : 100, - Signature : []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - } - // query for LSH Ensemble - query2 = &GraphWindow{ - Key : fmt.Sprintf("g%dn%do%d", 1, 2, 3), - Size : 50, - Signature : []uint64{1, 1, 3, 4, 5, 6, 7, 8, 9, 10}, - } -) - -// test the lshForest constructor, add a record and query it -func Test_lshForestConstructor(t *testing.T) { - index := NewLSHforest(len(entry1.Signature), jsThresh) - numHF, numBucks := index.Lshes[0].Settings() - t.Logf("\tnumber of hash functions per bucket: %d\n", numHF) - t.Logf("\tnumber of buckets: %d\n", numBucks) - index.Add(entry1.Key, entry1.Signature, 0) - index.Index() - done := make(chan struct{}) - defer close(done) - var check string - for result := range index.Query(query1.Signature, query1.Size, jsThresh, done) { - check = result.(string) - if check != entry1.Key { - t.Fatal() - } - } - if check == "" { - t.Fatal("no result from LSH Forest") - } -} - -// test the lshForest constructor and add a set of records, then query -func Test_lshForestBootstrap(t *testing.T) { - index := NewLSHforest(len(entry1.Signature), jsThresh) - for _, i := range entries { - index.Add(i.Key, i.Signature, 0) - } - if len(index.Partitions) != 1 || index.SingleForest != true { - t.Fatal() - } - index.Index() - done := make(chan struct{}) - defer close(done) - var check string - for result := range index.Query(query1.Signature, query1.Size, jsThresh, done) { - check = result.(string) - if check != entry1.Key { - t.Fatal("incorrect result returned from LSH Forest") - } - } - if check == "" { - t.Fatal("no result from LSH Forest") - } -} - -// test the lshForest dump and load methods -func Test_lshForestDump(t *testing.T) { - index := NewLSHforest(len(entry1.Signature), jsThresh) - for _, i := range entries { - index.Add(i.Key, i.Signature, 0) - } - if err := index.Dump("./lsh.index"); err != nil { - t.Fatal(err) - } - index2 := NewLSHforest(len(entry1.Signature), jsThresh) - if err := index2.Load("./lsh.index"); err != nil { - t.Fatal(err) - } - if err := os.Remove("./lsh.index"); err != nil { - t.Fatal(err) - } - done := make(chan struct{}) - defer close(done) - var check string - for result := range index2.Query(query1.Signature, query1.Size, jsThresh, done) { - check = result.(string) - if check != entry1.Key { - t.Fatal(check) - } - } - if check == "" { - t.Fatal("no result from LSH Forest") - } -} - - -// test the lshEnsemble constructor, add the records and query it -func Test_lshEnsembleBootstrap(t *testing.T) { - index := BootstrapLshEnsemble(numPart, numHash, maxK, len(entries), Windows2Chan(entries)) - index.Index() - done := make(chan struct{}) - defer close(done) - var check string - for result := range index.Query(query2.Signature, query2.Size, jsThresh, done) { - check = result.(string) - if check != entry1.Key { - t.Fatal("incorrect result returned from LSH Ensemble") - } - } - if check == "" { - t.Fatal("no result from LSH ensemble") - } -} - -// test the lshEnsemble dump and load methods -func Test_lshEnsembleDump(t *testing.T) { - index := BootstrapLshEnsemble(numPart, numHash, maxK, len(entries), Windows2Chan(entries)) - if err := index.Dump("./lsh.index"); err != nil { - t.Fatal(err) - } - index2 := NewLSHensemble(make([]Partition, numPart), numHash, maxK) - if err := index2.Load("./lsh.index"); err != nil { - t.Fatal(err) - } - if err := os.Remove("./lsh.index"); err != nil { - t.Fatal(err) - } - done := make(chan struct{}) - defer close(done) - var check string - for result := range index2.Query(query2.Signature, query2.Size, jsThresh, done) { - check = result.(string) - if check != entry1.Key { - t.Fatal() - } - } - if check == "" { - t.Fatal("no result from LSH ensemble") - } -} diff --git a/src/lshIndex/probability.go b/src/lshIndex/probability.go deleted file mode 100644 index a2c1e87..0000000 --- a/src/lshIndex/probability.go +++ /dev/null @@ -1,86 +0,0 @@ -// copy of https://github.com/ekzhu/lshensemble/blob/0322dae1f4d960f6fb3f9e6e2870786b9f4239ed/probability.go -package lshIndex - -import "math" - -// Compute the integral of function f, lower limit a, upper limit l, and -// precision defined as the quantize step -func integral(f func(float64) float64, a, b, precision float64) float64 { - var area float64 - for x := a; x < b; x += precision { - area += f(x+0.5*precision) * precision - } - return area -} - -/* - The following are using Jaccard similarity -*/ -// Probability density function for false positive -func falsePositive(l, k int) func(float64) float64 { - return func(j float64) float64 { - return 1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l)) - } -} - -// Probability density function for false negative -func falseNegative(l, k int) func(float64) float64 { - return func(j float64) float64 { - return 1.0 - (1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l))) - } -} - -// Compute the cummulative probability of false negative given threshold t -func probFalseNegative(l, k int, t, precision float64) float64 { - return integral(falseNegative(l, k), t, 1.0, precision) -} - -// Compute the cummulative probability of false positive given threshold t -func probFalsePositive(l, k int, t, precision float64) float64 { - return integral(falsePositive(l, k), 0, t, precision) -} - -/* - The following are using Jaccard containment TODO: consolidate these functions with the above -*/ -// Probability density function for false positive -func falsePositiveC(x, q, l, k int) func(float64) float64 { - return func(t float64) float64 { - return 1.0 - math.Pow(1.0-math.Pow(t/(1.0+float64(x)/float64(q)-t), float64(k)), float64(l)) - } -} - -// Probability density function for false negative -func falseNegativeC(x, q, l, k int) func(float64) float64 { - return func(t float64) float64 { - return 1.0 - (1.0 - math.Pow(1.0-math.Pow(t/(1.0+float64(x)/float64(q)-t), float64(k)), float64(l))) - } -} - -// Compute the cummulative probability of false negative -func probFalseNegativeC(x, q, l, k int, t, precision float64) float64 { - fn := falseNegativeC(x, q, l, k) - xq := float64(x) / float64(q) - if xq >= 1.0 { - return integral(fn, t, 1.0, precision) - } - if xq >= t { - return integral(fn, t, xq, precision) - } else { - return 0.0 - } -} - -// Compute the cummulative probability of false positive -func probFalsePositiveC(x, q, l, k int, t, precision float64) float64 { - fp := falsePositiveC(x, q, l, k) - xq := float64(x) / float64(q) - if xq >= 1.0 { - return integral(fp, 0.0, t, precision) - } - if xq >= t { - return integral(fp, 0.0, t, precision) - } else { - return integral(fp, 0.0, xq, precision) - } -} diff --git a/src/minhash/minhash.go b/src/minhash/minhash.go index 1d70c2e..4efe8d4 100644 --- a/src/minhash/minhash.go +++ b/src/minhash/minhash.go @@ -13,7 +13,7 @@ const CANONICAL = false // minHash struct contains all the minimum hash values for a sequence type minHash struct { - kSize int + kSize int signature []uint64 } @@ -64,7 +64,7 @@ func NewMinHash(kSize, sigSize int) *minHash { signature[i] = math.MaxUint64 } return &minHash{ - kSize: kSize, + kSize: kSize, signature: signature, } -} \ No newline at end of file +} diff --git a/src/minhash/minhash_test.go b/src/minhash/minhash_test.go index 2a052fd..0ef1ae1 100644 --- a/src/minhash/minhash_test.go +++ b/src/minhash/minhash_test.go @@ -5,9 +5,9 @@ import ( ) var ( - kSize = 11 - sigSize = 24 - sequence = []byte("ACTGCGTGCGTGAAACGTGCACGTGACGTG") + kSize = 11 + sigSize = 24 + sequence = []byte("ACTGCGTGCGTGAAACGTGCACGTGACGTG") sequence2 = []byte("TGACGCACGCACTTTGCACGTGCACTGCAC") ) @@ -55,7 +55,7 @@ func TestSimilarity(t *testing.T) { t.Fatal("incorrect JS calculation") } // make sure the method checks work - mh3 = NewMinHash(kSize, (sigSize+1)) + mh3 = NewMinHash(kSize, (sigSize + 1)) _ = mh3.Add(sequence) _, err = mh.Similarity(mh3.Signature()) if err == nil { diff --git a/src/misc/misc.go b/src/misc/misc.go index 59c1417..b763408 100644 --- a/src/misc/misc.go +++ b/src/misc/misc.go @@ -3,14 +3,15 @@ package misc import ( "encoding/binary" - "encoding/gob" "errors" + "io/ioutil" "log" "os" "strings" "github.com/spf13/cobra" "github.com/spf13/pflag" + "gopkg.in/vmihailenco/msgpack.v2" ) // a function to throw error to the log and exit the program @@ -85,27 +86,22 @@ type IndexInfo struct { SigSize int JSthresh float64 ReadLength int - Containment bool } // method to dump the info to file func (self *IndexInfo) Dump(path string) error { - file, err := os.Create(path) - if err == nil { - encoder := gob.NewEncoder(file) - encoder.Encode(self) + b, err := msgpack.Marshal(self) + if err != nil { + return err } - file.Close() - return err + return ioutil.WriteFile(path, b, 0644) } // method to load info from file func (self *IndexInfo) Load(path string) error { - file, err := os.Open(path) - if err == nil { - decoder := gob.NewDecoder(file) - err = decoder.Decode(self) + b, err := ioutil.ReadFile(path) + if err != nil { + return err } - file.Close() - return err + return msgpack.Unmarshal(b, self) } diff --git a/src/reporting/reporting.go b/src/reporting/reporting.go index 4cabab6..a1ea86c 100644 --- a/src/reporting/reporting.go +++ b/src/reporting/reporting.go @@ -169,14 +169,14 @@ func (proc *BAMreader) Run() { close(reportChan) }() - // this will clean up the ARG name so that we can use it as a filename - var replacer = strings.NewReplacer("/", "__", "\t", "__") - // collect the annotated ARGs for anno := range reportChan { // print info to stdout fmt.Printf("%v\t%d\t%d\t%v\n", anno.arg, anno.count, anno.length, anno.cigar) + // this will clean up the ARG name so that we can use it as a filename + var replacer = strings.NewReplacer("/", "__", "\t", "__") + // plot coverage for this gene if proc.Plot == true { covPlot, err := plot.New() diff --git a/src/seqio/seqio.go b/src/seqio/seqio.go index a3a71ba..ac94ab2 100644 --- a/src/seqio/seqio.go +++ b/src/seqio/seqio.go @@ -90,7 +90,7 @@ func (self *FASTQread) RevComplement() { } // method to split sequence to k-mers + get minhash signature -func (self *Sequence) RunMinHash(k int, sigSize int) ([]uint64, error){ +func (self *Sequence) RunMinHash(k int, sigSize int) ([]uint64, error) { // create the MinHash minhash := minhash.NewMinHash(k, sigSize) // use the add method to initate rolling ntHash and populate MinHash diff --git a/src/stream/stream.go b/src/stream/stream.go index 94714b3..4e7b5ed 100644 --- a/src/stream/stream.go +++ b/src/stream/stream.go @@ -8,19 +8,20 @@ import ( "compress/gzip" "errors" "fmt" + "log" + "os" + "strings" + "sync" + "time" + "github.com/biogo/hts/bam" "github.com/biogo/hts/sam" "github.com/will-rowe/groot/src/alignment" "github.com/will-rowe/groot/src/graph" - "github.com/will-rowe/groot/src/lshIndex" + "github.com/will-rowe/groot/src/lshForest" "github.com/will-rowe/groot/src/misc" "github.com/will-rowe/groot/src/seqio" "github.com/will-rowe/groot/src/version" - "log" - "os" - "strings" - "sync" - "time" ) const ( @@ -164,7 +165,6 @@ type FastqChecker struct { WindowSize int MinReadLength int MinQual int - Containment bool } func NewFastqChecker() *FastqChecker { @@ -207,10 +207,8 @@ func (proc *FastqChecker) Run() { meanRL := float64(lengthTotal) / float64(rawCount) log.Printf("\tmean read length: %.0f\n", meanRL) // check the length is within +/-10 bases of the graph window - if proc.Containment == false { - if meanRL < float64(proc.WindowSize-10) || meanRL > float64(proc.WindowSize+10) { - misc.ErrorCheck(fmt.Errorf("read length is too variable (> +/- 10 bases of graph window size), try re-indexing using the --containment option\n")) - } + if meanRL < float64(proc.WindowSize-10) || meanRL > float64(proc.WindowSize+10) { + misc.ErrorCheck(fmt.Errorf("mean read length is outside the graph window size (+/- 10 bases)\n")) } } @@ -221,10 +219,9 @@ type DbQuerier struct { process Input chan seqio.FASTQread Output chan seqio.FASTQread - Db *lshIndex.LshEnsemble + Db *lshForest.LSHforest CommandInfo *misc.IndexInfo GraphStore graph.GraphStore - Threshold float64 } func NewDbQuerier() *DbQuerier { @@ -250,13 +247,11 @@ func (proc *DbQuerier) Run() { read.RC = true // set RC flag so we can tell which orientation the read is in } // get signature for read - readMH, err := read.RunMinHash(proc.CommandInfo.Ksize, proc.CommandInfo.SigSize) + readSketch, err := read.RunMinHash(proc.CommandInfo.Ksize, proc.CommandInfo.SigSize) misc.ErrorCheck(err) - // query the LSH index - done := make(chan struct{}) - defer close(done) - for result := range proc.Db.Query(readMH, len(read.Seq), proc.Threshold, done) { - seed := proc.Db.KeyLookup[result.(string)] + // query the LSH forest + for _, result := range proc.Db.Query(readSketch) { + seed := proc.Db.KeyLookup[result] seed.RC = read.RC seeds = append(seeds, seed) } diff --git a/src/version/version.go b/src/version/version.go index debfbe0..f6f9779 100644 --- a/src/version/version.go +++ b/src/version/version.go @@ -1,3 +1,3 @@ package version -const VERSION = "0.8.4" +const VERSION = "0.8.5"