Skip to content

Commit

Permalink
Refactor set_chain
Browse files Browse the repository at this point in the history
  • Loading branch information
radupotop committed Apr 16, 2024
1 parent c166046 commit 0ef2cc2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
1 change: 0 additions & 1 deletion app/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def application(request):
token_instance = storage.get_token(token)

if token_instance and token_instance.is_valid:
ipt.get_chain()
if not ipt.has_rule(src_ip):
ipt.add_rule(src_ip)
storage.log_access_request(src_ip, token_instance)
Expand Down
20 changes: 10 additions & 10 deletions app/backend/iptables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ class IPTables:
def __init__(self, config: ConfigReader):
self.config = config
self.filter_table = iptc.Table(iptc.Table.FILTER)
self.set_chain()

def set_chain(self):
"""
Assume the whitelist chain exists and set it.
This does NOT create the chain and will not error if it doesn't exist.
Use setup_whitelist_chain for that.
"""
log.info('Whitelist chain: %s', self.config.chain)
self.chain = iptc.Chain(self.filter_table, self.config.chain)

def setup_whitelist_chain(self):
"""
Expand Down Expand Up @@ -51,12 +61,6 @@ def setup_input_chain(self, set_policy_drop=False):
log.warning('Setting the INPUT chain Policy to DROP')
input_chain.set_policy(iptc.Policy.DROP)

def get_chain(self):
"""
Get the opensesame chain.
"""
self.chain = iptc.Chain(self.filter_table, self.config.chain)

def build_inbound_rule(
self, port: str, protocol: str, always_accept: bool = False
) -> iptc.Rule:
Expand Down Expand Up @@ -84,8 +88,6 @@ def add_rule(self, src_ip: str) -> bool:
Example:
iptables -A opensesame -s SRC_IP -j ACCEPT
"""
if not hasattr(self, 'chain'):
self.get_chain()
rule = iptc.Rule()
rule.src = parse_ip(src_ip)
rule.target = iptc.Target(rule, iptc.Policy.ACCEPT)
Expand All @@ -109,8 +111,6 @@ def delete_rule(self, src_ip: str) -> bool:
"""
Drop a rule from the opensesame chain.
"""
if not hasattr(self, 'chain'):
self.get_chain()
found_rules = self._lookup_rules(src_ip)
for rule in found_rules:
self.chain.delete_rule(rule)
Expand Down

0 comments on commit 0ef2cc2

Please sign in to comment.