Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify compare definitions and other improvements #13

Merged
merged 1 commit into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
cmake_minimum_required(VERSION 3.20 FATAL_ERROR)

find_program(CCACHE_PROGRAM ccache)
if (CCACHE_PROGRAM)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}")
endif()

set(CPP_ORDERBOOK orderbook)
project(${CPP_ORDERBOOK} LANGUAGES CXX)

Expand Down
16 changes: 8 additions & 8 deletions include/orderbook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ class OrderBook {
OrderBook(NotificationInterface<Notification>& n, size_t price_level_pool_size = 16384, size_t order_pool_size = 16384)
: order_pool_(order_pool_size),
notification_(static_cast<Notification&>(n)),
bids_(PriceLevel<CmpGreater>(PriceType::Bid, price_level_pool_size)),
asks_(PriceLevel<CmpLess>(PriceType::Ask, price_level_pool_size)),
trigger_over_(PriceLevel<CmpGreater>(PriceType::Trigger, price_level_pool_size)),
trigger_under_(PriceLevel<CmpLess>(PriceType::Trigger, price_level_pool_size)),
bids_(PriceLevel<PriceType::Bid>(price_level_pool_size)),
asks_(PriceLevel<PriceType::Ask>(price_level_pool_size)),
trigger_over_(PriceLevel<PriceType::TriggerOver>(price_level_pool_size)),
trigger_under_(PriceLevel<PriceType::TriggerUnder>(price_level_pool_size)),
orders_(OrderMap()),
trig_orders_(OrderMap()){};

Expand All @@ -39,10 +39,10 @@ class OrderBook {
private:
pool::AdaptiveObjectPool<Order> order_pool_;

PriceLevel<CmpGreater> bids_;
PriceLevel<CmpLess> asks_;
PriceLevel<CmpGreater> trigger_over_;
PriceLevel<CmpLess> trigger_under_;
PriceLevel<PriceType::Bid> bids_;
PriceLevel<PriceType::Ask> asks_;
PriceLevel<PriceType::TriggerOver> trigger_over_;
PriceLevel<PriceType::TriggerUnder> trigger_under_;

OrderMap orders_;
OrderMap trig_orders_;
Expand Down
5 changes: 3 additions & 2 deletions include/orderqueue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class OrderQueue : public boost::intrusive::set_base_hook<boost::intrusive::opti
Order *head_ = nullptr;
Order *tail_ = nullptr;

Decimal price_;
Decimal total_qty_;
uint64_t size_ = 0;

Expand All @@ -30,11 +31,11 @@ class OrderQueue : public boost::intrusive::set_base_hook<boost::intrusive::opti
void remove(Order *o);
Decimal process(const TradeNotification &tn, const PostOrderFill &postFill, OrderID takerOrderID, Decimal qty);

Decimal price_;

friend bool operator<(const OrderQueue &a, const OrderQueue &b) { return a.price_ < b.price_; }
friend bool operator>(const OrderQueue &a, const OrderQueue &b) { return a.price_ > b.price_; }
friend bool operator==(const OrderQueue &a, const OrderQueue &b) { return a.price_ == b.price_; }

friend class PriceCompare;
};

struct PriceCompare {
Expand Down
44 changes: 11 additions & 33 deletions include/pricelevel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,58 +20,36 @@ class Compare {
using CmpGreater = boost::intrusive::compare<std::greater<>>;
using CmpLess = boost::intrusive::compare<std::less<>>;

template <class CompareType>
template <PriceType P>
class PriceLevel {
pool::AdaptiveObjectPool<OrderQueue> queue_pool_;

using CompareType = std::conditional_t<(P == PriceType::Bid || P == PriceType::TriggerOver), CmpGreater, CmpLess>;
using PriceTree = boost::intrusive::rbtree<OrderQueue, CompareType>;
PriceTree price_tree_;

PriceType price_type_;
PriceType price_type_ = P;
Decimal volume_;
uint64_t num_orders_ = 0;
uint64_t depth_ = 0;

public:
PriceLevel(PriceType price_type, size_t price_level_pool_size) : price_type_(price_type), queue_pool_(price_level_pool_size){};
PriceLevel(size_t price_level_pool_size) : queue_pool_(price_level_pool_size){};
uint64_t len();
uint64_t depth();
Decimal volume();
OrderQueue* getQueue();
[[nodiscard]] OrderQueue* getQueue();
[[nodiscard]] OrderQueue* getNextQueue(const Decimal& price);
[[nodiscard]] OrderQueue* largestLessThan(const Decimal& price);
[[nodiscard]] OrderQueue* smallestGreaterThan(const Decimal& price);

void append(Order* order);
void remove(Order* order);

Decimal processMarketOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID takerOrderID, Decimal qty, Flag flag);
Decimal processLimitOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID& takerOrderID, Decimal& price, Decimal qty, Flag& flag);

PriceTree& price_tree() { return price_tree_; }

OrderQueue* LargestLessThan(const Decimal& price) {
auto it = price_tree_.lower_bound(price, PriceCompare());
if (it != price_tree_.begin()) {
--it;
return &(*it);
}
return nullptr;
}

OrderQueue* SmallestGreaterThan(const Decimal& price) {
auto it = price_tree_.upper_bound(price, PriceCompare());
if (it != price_tree_.end()) {
return &(*it);
}
return nullptr;
}

OrderQueue* GetNextQueue(const Decimal& price) {
switch (price_type_) {
case PriceType::Bid:
return LargestLessThan(price);
case PriceType::Ask:
return SmallestGreaterThan(price);
default:
throw std::runtime_error("invalid call to GetQueue");
}
}
PriceTree& price_tree() { return price_tree_; };
};

} // namespace orderbook
Expand Down
4 changes: 3 additions & 1 deletion include/types.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <cstdint>
#include <cwchar>
#include <iostream>

#include "decimal.hpp"
Expand Down Expand Up @@ -44,7 +45,8 @@ std::ostream& operator<<(std::ostream& os, const OrderStatus& status);
enum class PriceType {
Bid,
Ask,
Trigger,
TriggerOver,
TriggerUnder,
};

std::ostream& operator<<(std::ostream& os, const PriceType& priceType);
Expand Down
2 changes: 1 addition & 1 deletion src/order.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace orderbook {

Decimal Order::getPrice(PriceType pt) {
if (pt == PriceType::Trigger) {
if (pt == PriceType::TriggerOver || pt == PriceType::TriggerUnder) [[unlikely]] {
return trig_price;
}

Expand Down
74 changes: 53 additions & 21 deletions src/pricelevel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@

namespace orderbook {

template <class CompareType>
uint64_t PriceLevel<CompareType>::len() {
template <PriceType P>
uint64_t PriceLevel<P>::len() {
return num_orders_;
}

template <class CompareType>
uint64_t PriceLevel<CompareType>::depth() {
template <PriceType P>
uint64_t PriceLevel<P>::depth() {
return depth_;
}

template <class CompareType>
Decimal PriceLevel<CompareType>::volume() {
template <PriceType P>
Decimal PriceLevel<P>::volume() {
return volume_;
}

template <class CompareType>
void PriceLevel<CompareType>::append(Order* order) {
template <PriceType P>
void PriceLevel<P>::append(Order* order) {
auto price = order->getPrice(price_type_);

auto it = price_tree_.find(price);
Expand All @@ -45,8 +45,8 @@ void PriceLevel<CompareType>::append(Order* order) {
q->append(order);
}

template <class CompareType>
void PriceLevel<CompareType>::remove(Order* order) {
template <PriceType P>
void PriceLevel<P>::remove(Order* order) {
auto price = order->getPrice(price_type_);

auto q = order->queue;
Expand All @@ -64,8 +64,8 @@ void PriceLevel<CompareType>::remove(Order* order) {
volume_ -= order->qty;
}

template <class CompareType>
OrderQueue* PriceLevel<CompareType>::getQueue() {
template <PriceType P>
OrderQueue* PriceLevel<P>::getQueue() {
auto q = price_tree_.begin();
if (q != price_tree_.end()) {
return &*q;
Expand All @@ -74,8 +74,8 @@ OrderQueue* PriceLevel<CompareType>::getQueue() {
return nullptr;
}

template <class CompareType>
Decimal PriceLevel<CompareType>::processMarketOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID takerOrderID, Decimal qty, Flag flag) {
template <PriceType P>
Decimal PriceLevel<P>::processMarketOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID takerOrderID, Decimal qty, Flag flag) {
// TODO: this won't work as pricelevel volumes aren't accounted for correctly
if ((flag & (AoN | FoK)) != 0 && qty > volume_) {
return uint64_t(0);
Expand All @@ -92,9 +92,8 @@ Decimal PriceLevel<CompareType>::processMarketOrder(const TradeNotification& tn,
return uint64_t(0);
};

template <class CompareType>
Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID& takerOrderID, Decimal& price, Decimal qty,
Flag& flag) {
template <PriceType P>
Decimal PriceLevel<P>::processLimitOrder(const TradeNotification& tn, const PostOrderFill& pf, OrderID& takerOrderID, Decimal& price, Decimal qty, Flag& flag) {
Decimal qtyProcessed = {};
auto orderQueue = getQueue();

Expand Down Expand Up @@ -127,7 +126,7 @@ Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn,
break;
}
aQty -= orderQueue->totalQty();
orderQueue = GetNextQueue(orderQueue->price());
orderQueue = getNextQueue(orderQueue->price());
}
} else {
while (orderQueue != nullptr && price > orderQueue->price()) {
Expand All @@ -136,7 +135,7 @@ Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn,
break;
}
aQty -= orderQueue->totalQty();
orderQueue = GetNextQueue(orderQueue->price());
orderQueue = getNextQueue(orderQueue->price());
}
}

Expand All @@ -157,7 +156,40 @@ Decimal PriceLevel<CompareType>::processLimitOrder(const TradeNotification& tn,
return qtyProcessed;
};

template class PriceLevel<CmpGreater>;
template class PriceLevel<CmpLess>;
template <PriceType P>
OrderQueue* PriceLevel<P>::largestLessThan(const Decimal& price) {
auto it = price_tree_.lower_bound(price, PriceCompare());
if (it != price_tree_.begin()) {
--it;
return &(*it);
}
return nullptr;
}

template <PriceType P>
OrderQueue* PriceLevel<P>::smallestGreaterThan(const Decimal& price) {
auto it = price_tree_.upper_bound(price, PriceCompare());
if (it != price_tree_.end()) {
return &(*it);
}
return nullptr;
}

template <PriceType P>
OrderQueue* PriceLevel<P>::getNextQueue(const Decimal& price) {
switch (price_type_) {
case PriceType::Bid:
return largestLessThan(price);
case PriceType::Ask:
return smallestGreaterThan(price);
default:
throw std::runtime_error("invalid call to GetQueue");
}
}

template class PriceLevel<PriceType::Bid>;
template class PriceLevel<PriceType::Ask>;
template class PriceLevel<PriceType::TriggerOver>;
template class PriceLevel<PriceType::TriggerUnder>;

} // namespace orderbook
6 changes: 4 additions & 2 deletions src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ std::ostream& operator<<(std::ostream& os, const PriceType& priceType) {
return os << "Bid";
case PriceType::Ask:
return os << "Ask";
case PriceType::Trigger:
return os << "Trigger";
case PriceType::TriggerOver:
return os << "TriggerOver";
case PriceType::TriggerUnder:
return os << "TriggerUnder";
}
return os << "Unknown";
}
Expand Down
39 changes: 21 additions & 18 deletions test/pricelevel_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class PriceLevelTest : public ::testing::Test {
};

TEST_F(PriceLevelTest, TestPriceLevel) {
PriceLevel<CmpLess> bidLevel(PriceType::Bid, 10);
PriceLevel<PriceType::Bid> bidLevel(10);

auto o1 = std::make_shared<Order>(1, Type::Limit, Side::Buy, Decimal(10, 0), Decimal(10, 0), Decimal(uint64_t(0)), Flag::None);
auto o2 = std::make_shared<Order>(2, Type::Limit, Side::Buy, Decimal(10, 0), Decimal(20, 0), Decimal(uint64_t(0)), Flag::None);
Expand All @@ -43,9 +43,10 @@ TEST_F(PriceLevelTest, TestPriceLevel) {
ASSERT_EQ(bidLevel.depth(), 2);
ASSERT_EQ(bidLevel.len(), 2);

if (tree.begin()->head() != o1.get() || tree.begin()->tail() != o1.get() || tree.rbegin()->head() != o2.get() || tree.rbegin()->tail() != o2.get()) {
FAIL() << "invalid price levels";
}
ASSERT_EQ(tree.begin()->head(), o2.get()) << "Invalid price levels: head of the first element does not match o1";
ASSERT_EQ(tree.begin()->tail(), o2.get()) << "Invalid price levels: tail of the first element does not match o1";
ASSERT_EQ(tree.rbegin()->head(), o1.get()) << "Invalid price levels: head of the last element does not match o2";
ASSERT_EQ(tree.rbegin()->tail(), o1.get()) << "Invalid price levels: tail of the last element does not match o2";

bidLevel.remove(o1.get());

Expand All @@ -62,7 +63,7 @@ TEST_F(PriceLevelTest, TestPriceLevel) {
}

TEST_F(PriceLevelTest, TestPriceFinding) {
PriceLevel<CmpLess> askLevel(PriceType::Ask, 10);
PriceLevel<PriceType::Ask> askLevel(10);

askLevel.append(new Order(1, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(130, 0), Decimal(uint64_t(0)), Flag::None));
askLevel.append(new Order(2, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(170, 0), Decimal(uint64_t(0)), Flag::None));
Expand All @@ -75,17 +76,17 @@ TEST_F(PriceLevelTest, TestPriceFinding) {

ASSERT_EQ(askLevel.volume(), Decimal(40, 0));

ASSERT_EQ(askLevel.LargestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
ASSERT_EQ(askLevel.LargestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(askLevel.LargestLessThan(Decimal(100, 0)), nullptr);
ASSERT_EQ(askLevel.largestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
ASSERT_EQ(askLevel.largestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(askLevel.largestLessThan(Decimal(100, 0)), nullptr);

ASSERT_EQ(askLevel.SmallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(askLevel.SmallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(askLevel.SmallestGreaterThan(Decimal(170, 0)), nullptr);
ASSERT_EQ(askLevel.smallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(askLevel.smallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(askLevel.smallestGreaterThan(Decimal(170, 0)), nullptr);
}

TEST_F(PriceLevelTest, TestStopQueuePriceFinding) {
PriceLevel<CmpLess> trigLevel(PriceType::Trigger, 10);
PriceLevel<PriceType::TriggerUnder> trigLevel(10);

trigLevel.append(new Order(1, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(10, 0), Decimal(130, 0), Flag::None));
trigLevel.append(new Order(2, Type::Limit, Side::Sell, Decimal(5, 0), Decimal(20, 0), Decimal(170, 0), Flag::None));
Expand All @@ -98,13 +99,15 @@ TEST_F(PriceLevelTest, TestStopQueuePriceFinding) {

ASSERT_EQ(trigLevel.volume(), Decimal(40, 0));

ASSERT_EQ(trigLevel.LargestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
ASSERT_EQ(trigLevel.LargestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(trigLevel.LargestLessThan(Decimal(100, 0)), nullptr);
std::cout << 1 << std::endl;
ASSERT_EQ(trigLevel.largestLessThan(Decimal(101, 0))->price(), Decimal(100, 0));
std::cout << 2 << std::endl;
ASSERT_EQ(trigLevel.largestLessThan(Decimal(150, 0))->price(), Decimal(140, 0));
ASSERT_EQ(trigLevel.largestLessThan(Decimal(100, 0)), nullptr);

ASSERT_EQ(trigLevel.SmallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(trigLevel.SmallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(trigLevel.SmallestGreaterThan(Decimal(170, 0)), nullptr);
ASSERT_EQ(trigLevel.smallestGreaterThan(Decimal(169, 0))->price(), Decimal(170, 0));
ASSERT_EQ(trigLevel.smallestGreaterThan(Decimal(150, 0))->price(), Decimal(160, 0));
ASSERT_EQ(trigLevel.smallestGreaterThan(Decimal(170, 0)), nullptr);
}

} // namespace test
Expand Down