Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Atomic nibble instead of mutex #1601

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions khmer/_cpy_smallcountgraph.hh
Original file line number Diff line number Diff line change
Expand Up @@ -46,40 +46,7 @@ typedef struct {

static void khmer_smallcountgraph_dealloc(khmer_KSmallCountgraph_Object * obj);

// Python method: return the raw storage tables of a SmallCountgraph.
//
// Builds a new Python list holding one memoryview per table. Each view
// spans `tablesize / 2 + 1` bytes, the length used for the table at
// index i. The buffers are filled in with a NULL exporter, so the
// memoryviews do not hold a reference to the graph that owns the memory.
//
// `args` is not inspected.
//
// Returns a new reference to the list, or NULL if any C-API call fails.
static
PyObject *
smallcount_get_raw_tables(khmer_KSmallCountgraph_Object * self, PyObject * args)
{
    SmallCountgraph * countgraph = self->countgraph;

    khmer::Byte ** table_ptrs = countgraph->get_raw_tables();
    std::vector<uint64_t> sizes = countgraph->get_tablesizes();

    PyObject * raw_tables = PyList_New(sizes.size());
    if (raw_tables == NULL) {
        // Allocation failed; there is no list to fill.
        return NULL;
    }

    for (size_t i = 0; i < sizes.size(); ++i) {
        Py_buffer buffer;
        int res = PyBuffer_FillInfo(&buffer, NULL, table_ptrs[i],
                                    sizes[i] / 2 + 1, 0,
                                    PyBUF_FULL_RO);
        if (res == -1) {
            // Drop the partially filled list instead of leaking it.
            Py_DECREF(raw_tables);
            return NULL;
        }

        PyObject * buf = PyMemoryView_FromBuffer(&buffer);
        if (buf == NULL) {
            // Test for NULL explicitly: PyMemoryView_Check() inspects the
            // object's type and must not be handed a NULL pointer.
            Py_DECREF(raw_tables);
            return NULL;
        }

        // PyList_SET_ITEM steals the reference to buf.
        PyList_SET_ITEM(raw_tables, i, buf);
    }

    return raw_tables;
}

// Method table exposing smallcount_get_raw_tables to Python under the name
// "get_raw_tables". Each entry is {name, C function, call flags, docstring};
// the all-NULL entry terminates the table.
// NOTE(review): presumably installed as the SmallCountgraph type's tp_methods;
// the type object is not visible here -- confirm.
static PyMethodDef khmer_smallcountgraph_methods[] = {
{
"get_raw_tables",
(PyCFunction)smallcount_get_raw_tables, METH_VARARGS,
"Get a list of the raw storage tables as memoryview objects."
},
{NULL, NULL, 0, NULL} /* sentinel */
};

Expand Down
31 changes: 0 additions & 31 deletions khmer/_khmer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2585,44 +2585,13 @@ nodegraph_update(khmer_KNodegraph_Object * me, PyObject * args)
Py_RETURN_NONE;
}

// Python method: return the raw storage tables of a Nodegraph.
//
// Builds a new Python list holding one memoryview per table. Each view
// spans `sizes[i]` bytes of the table at index i. The buffers are filled
// in with a NULL exporter, so the memoryviews do not hold a reference to
// the graph that owns the memory.
//
// `args` is not inspected.
//
// Returns a new reference to the list, or NULL if any C-API call fails.
static
PyObject *
nodegraph_get_raw_tables(khmer_KNodegraph_Object * self, PyObject * args)
{
    Nodegraph * countgraph = self->nodegraph;

    khmer::Byte ** table_ptrs = countgraph->get_raw_tables();
    std::vector<uint64_t> sizes = countgraph->get_tablesizes();

    PyObject * raw_tables = PyList_New(sizes.size());
    if (raw_tables == NULL) {
        // Allocation failed; there is no list to fill.
        return NULL;
    }

    for (size_t i = 0; i < sizes.size(); ++i) {
        Py_buffer buffer;
        int res = PyBuffer_FillInfo(&buffer, NULL, table_ptrs[i], sizes[i], 0,
                                    PyBUF_FULL_RO);
        if (res == -1) {
            // Drop the partially filled list instead of leaking it.
            Py_DECREF(raw_tables);
            return NULL;
        }

        PyObject * buf = PyMemoryView_FromBuffer(&buffer);
        if (buf == NULL) {
            // Test for NULL explicitly: PyMemoryView_Check() inspects the
            // object's type and must not be handed a NULL pointer.
            Py_DECREF(raw_tables);
            return NULL;
        }

        // PyList_SET_ITEM steals the reference to buf.
        PyList_SET_ITEM(raw_tables, i, buf);
    }

    return raw_tables;
}

// Method table exposing nodegraph_update ("update") and
// nodegraph_get_raw_tables ("get_raw_tables") to Python. Each entry is
// {name, C function, call flags, docstring}; the all-NULL entry terminates
// the table.
// NOTE(review): presumably installed as the Nodegraph type's tp_methods;
// the type object is not visible here -- confirm.
static PyMethodDef khmer_nodegraph_methods[] = {
{
"update",
(PyCFunction) nodegraph_update, METH_VARARGS,
"a set update: update this nodegraph with all the entries from the other"
},
{
"get_raw_tables",
(PyCFunction) nodegraph_get_raw_tables, METH_VARARGS,
"Get a list of the raw tables as memoryview objects"
},
{NULL, NULL, 0, NULL} /* sentinel */
};

Expand Down
6 changes: 6 additions & 0 deletions lib/hashgraph.hh
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,12 @@ class Countgraph : public khmer::Hashgraph
public:
// Construct a Countgraph: forwards `ksize` to the Hashgraph base together
// with a newly allocated ByteStorage built from `sizes`.
// NOTE(review): the ByteStorage is allocated with `new` and handed to
// Hashgraph; presumably Hashgraph takes ownership and frees it -- confirm.
explicit Countgraph(WordLength ksize, std::vector<uint64_t> sizes)
: Hashgraph(ksize, new ByteStorage(sizes)) { } ;

// Get access to the raw tables of the underlying ByteStorage.
//
// The constructor of this class allocates `store` as a ByteStorage, so the
// downcast is written as a static_cast: it is greppable and cannot silently
// degrade into a reinterpret_cast the way a C-style cast can.
Byte ** get_raw_tables()
{
    return static_cast<ByteStorage*>(store)->get_raw_tables();
}
};

// Hashgraph-derived class with NibbleStorage.
Expand Down
6 changes: 0 additions & 6 deletions lib/hashtable.hh
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,6 @@ public:
void get_kmer_counts(const std::string &s,
std::vector<BoundedCounterType> &counts) const;

// get access to raw tables.
// Forwards to get_raw_tables() on `store` and returns its result unchanged.
Byte ** get_raw_tables()
{
return store->get_raw_tables();
}

// find the minimum k-mer count in the given sequence
BoundedCounterType get_min_count(const std::string &s);

Expand Down
3 changes: 3 additions & 0 deletions lib/khmer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private:\
className(const className&);\
const className& operator=(const className&)

#include <atomic>
#include <set>
#include <map>
#include <queue>
Expand Down Expand Up @@ -121,6 +122,8 @@ typedef unsigned short int BoundedCounterType;

// A single-byte type.
typedef unsigned char Byte;
// A std::atomic wrapper around uint8_t, for bytes that are read and
// written through atomic operations.
using AtomicByte = std::atomic<uint8_t>;


typedef void (*CallbackFn)(const char * info, void * callback_data,
unsigned long long n_reads,
Expand Down
4 changes: 2 additions & 2 deletions lib/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ void NibbleStorage::load(std::string infilename, WordLength& ksize)
_n_tables = (unsigned int) save_n_tables;
_occupied_bins = save_occupied_bins;

_counts = new Byte*[_n_tables];
_counts = new AtomicByte*[_n_tables];
for (unsigned int i = 0; i < _n_tables; i++) {
_counts[i] = NULL;
}
Expand All @@ -887,7 +887,7 @@ void NibbleStorage::load(std::string infilename, WordLength& ksize)
tablesize = save_tablesize;
_tablesizes.push_back(tablesize);

_counts[i] = new Byte[tablebytes];
_counts[i] = new AtomicByte[tablebytes];

unsigned long long loaded = 0;
while (loaded != tablebytes) {
Expand Down
54 changes: 26 additions & 28 deletions lib/storage.hh
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ Contact: [email protected]
#ifndef STORAGE_HH
#define STORAGE_HH

#include <cassert>
#include <array>
#include <mutex>
using MuxGuard = std::lock_guard<std::mutex>;

namespace khmer
{
Expand All @@ -67,7 +63,6 @@ public:
virtual BoundedCounterType test_and_set_bits( HashIntoType khash ) = 0;
virtual void add(HashIntoType khash) = 0;
virtual const BoundedCounterType get_count(HashIntoType khash) const = 0;
virtual Byte ** get_raw_tables() = 0;

void set_use_bigcount(bool b);
bool get_use_bigcount();
Expand Down Expand Up @@ -214,13 +209,6 @@ public:
return 1;
}

// Writing to the tables outside of defined methods has undefined behavior!
// As such, this should only be used to return read-only interfaces
//
// Returns the internal `_counts` pointer itself; no copy of the tables is
// made.
Byte ** get_raw_tables()
{
return _counts;
}
void update_from(const BitStorage&);
};

Expand All @@ -246,9 +234,8 @@ protected:
size_t _n_tables;
uint64_t _occupied_bins;
uint64_t _n_unique_kmers;
std::array<std::mutex, 32> mutexes;
static constexpr uint8_t _max_count{15};
Byte ** _counts;
AtomicByte ** _counts;

// Compute index into the table, this retrieves the correct byte
// which you then need to select the correct nibble from
Expand All @@ -271,8 +258,6 @@ public:
NibbleStorage(std::vector<uint64_t>& tablesizes) :
_tablesizes{tablesizes}, _occupied_bins{0}, _n_unique_kmers{0}
{
// to allow more than 32 tables increase the size of mutex pool
assert(_n_tables <= 32);
_allocate_counters();
}

Expand All @@ -293,13 +278,13 @@ public:
{
_n_tables = _tablesizes.size();

_counts = new Byte*[_n_tables];
_counts = new AtomicByte*[_n_tables];

for (size_t i = 0; i < _n_tables; i++) {
const uint64_t tablesize = _tablesizes[i];
const uint64_t tablebytes = tablesize / 2 + 1;

_counts[i] = new Byte[tablebytes];
_counts[i] = new AtomicByte[tablebytes];
memset(_counts[i], 0, tablebytes);
}
}
Expand All @@ -317,12 +302,12 @@ public:
bool is_new_kmer = false;

for (unsigned int i = 0; i < _n_tables; i++) {
MuxGuard g(mutexes[i]);
Byte* const table(_counts[i]);
AtomicByte* const table(_counts[i]);
const uint64_t idx = _table_index(khash, _tablesizes[i]);
const uint8_t mask = _mask(khash, _tablesizes[i]);
const uint8_t shift = _shift(khash, _tablesizes[i]);
const uint8_t current_count = (table[idx] & mask) >> shift;
uint8_t current_tbl = table[idx];
uint8_t current_count = (current_tbl & mask) >> shift;

if (!is_new_kmer) {
if (current_count == 0) {
Expand All @@ -342,8 +327,25 @@ public:
}

// increase count, no checking for overflow
const uint8_t new_count = (current_count + 1) << shift;
table[idx] = (table[idx] & ~mask) | (new_count & mask);
// current_tbl and new_tbl are the current and new bit-packed values
// for the idx'th byte of the table.
// compare_exchange_weak updates table[idx] to new_tbl only if it still
// holds current_tbl (i.e. it has not been changed by a different
// thread). If the two differ, the value actually stored in table[idx]
// is written to current_tbl instead, so this is a compare-and-swap
// loop.
uint8_t new_count = (current_count + 1) << shift;
uint8_t new_tbl = (current_tbl & ~mask) | (new_count & mask);

while(!table[idx].compare_exchange_weak(current_tbl, new_tbl)) {
current_count = (current_tbl & mask) >> shift;
new_count = (current_count + 1);
if (new_count > _max_count) {
break;
}
new_count <<= shift;
new_tbl = (current_tbl & ~mask) | (new_count & mask);
}
}

if (is_new_kmer) {
Expand All @@ -358,7 +360,7 @@ public:

// get the minimum count across all tables
for (unsigned int i = 0; i < _n_tables; i++) {
const Byte* table(_counts[i]);
const AtomicByte* table(_counts[i]);
const uint64_t idx = _table_index(khash, _tablesizes[i]);
const uint8_t mask = _mask(khash, _tablesizes[i]);
const uint8_t shift = _shift(khash, _tablesizes[i]);
Expand Down Expand Up @@ -391,10 +393,6 @@ public:
void save(std::string outfilename, WordLength ksize);
void load(std::string infilename, WordLength& ksize);

// Returns the internal `_counts` pointer itself; no copy of the tables is
// made.
Byte ** get_raw_tables()
{
return _counts;
}
};


Expand Down
23 changes: 4 additions & 19 deletions tests/test_countgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,11 @@ def test_get_raw_tables():


def test_get_raw_tables_smallcountgraph():
# for the same number of entries a SmallCountgraph uses ~half the memory
# of a normal Countgraph
# smallcountgraphs store individual counts packed into a byte, the raw
# tables probably do not give users what they expect (something that can be
# given to numpy.frombuffer)
ht = khmer.SmallCountgraph(20, 1e5, 4)
tables = ht.get_raw_tables()

for size, table in zip(ht.hashsizes(), tables):
assert isinstance(table, memoryview)
assert size // 2 + 1 == len(table)
assert not hasattr(ht, 'get_raw_tables')


def test_get_raw_tables_view():
Expand All @@ -229,18 +226,6 @@ def test_get_raw_tables_view():
assert sum(tab.tolist()) == 1


def test_get_raw_tables_view_smallcountgraph():
    """Raw-table views of a SmallCountgraph reflect a later consume()."""
    graph = khmer.SmallCountgraph(4, 1e5, 4)
    raw_views = graph.get_raw_tables()

    # Nothing has been consumed yet, so every table sums to zero.
    for view in raw_views:
        byte_total = sum(view.tolist())
        assert byte_total == 0

    graph.consume('AAAA')

    # The count itself is 1, but it is stored in the first 4 bits of a
    # Byte, so the byte value seen through the view is 16.
    expected_total = int('00010000', 2)
    for view in raw_views:
        byte_total = sum(view.tolist())
        assert byte_total == expected_total


@pytest.mark.huge
def test_toobig():
try:
Expand Down
12 changes: 4 additions & 8 deletions tests/test_nodegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,15 +553,11 @@ def test_extract_unique_paths_4():


def test_get_raw_tables():
# nodegraphs store individual bits packed into a byte, the raw tables
# probably do not give users what they expect (something that can be
# given to numpy.frombuffer)
kh = khmer.Nodegraph(10, 1e6, 4)
kh.consume('ATGGAGAGAC')
kh.consume('AGTGGCGATG')
kh.consume('ATAGACAGGA')
tables = kh.get_raw_tables()

for size, table in zip(kh.hashsizes(), tables):
assert isinstance(table, memoryview)
assert size == len(table)
assert not hasattr(kh, 'get_raw_tables')


def test_simple_median():
Expand Down