diff --git a/pom.xml b/pom.xml index ba2a61a..a87d7ab 100644 --- a/pom.xml +++ b/pom.xml @@ -217,6 +217,11 @@ mssql-jdbc 8.2.2.jre8 + + dev.failsafe + failsafe + 3.0.0 + diff --git a/src/main/java/org/casbin/adapter/JDBCAdapter.java b/src/main/java/org/casbin/adapter/JDBCAdapter.java index cd80be5..86af2ae 100644 --- a/src/main/java/org/casbin/adapter/JDBCAdapter.java +++ b/src/main/java/org/casbin/adapter/JDBCAdapter.java @@ -14,6 +14,7 @@ package org.casbin.adapter; +import dev.failsafe.Failsafe; import org.casbin.jcasbin.exception.CasbinAdapterException; import org.casbin.jcasbin.model.Model; import org.casbin.jcasbin.persist.FilteredAdapter; @@ -99,7 +100,11 @@ public boolean isFiltered() { * loadFilteredPolicyFile loads only policy rules that match the filter from file. */ private void loadFilteredPolicyFile(Model model, Filter filter, Helper.loadPolicyLineHandler handler) throws CasbinAdapterException { - try (Statement stmt = conn.createStatement()) { + Failsafe.with(retryPolicy).run(ctx -> { + if (ctx.isRetry()) { + retry(ctx); + } + Statement stmt = conn.createStatement(); ResultSet rSet = stmt.executeQuery("SELECT * FROM casbin_rule"); ResultSetMetaData rData = rSet.getMetaData(); while (rSet.next()) { @@ -127,10 +132,7 @@ private void loadFilteredPolicyFile(Model model, Filter filter, Helper.loadPolic loadPolicyLine(line, model); } rSet.close(); - } catch (SQLException e) { - e.printStackTrace(); - throw new Error(e); - } + }); } /** diff --git a/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java b/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java index 1f33abd..61415c8 100644 --- a/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java +++ b/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java @@ -14,8 +14,10 @@ package org.casbin.adapter; +import dev.failsafe.ExecutionContext; +import dev.failsafe.Failsafe; +import dev.failsafe.RetryPolicy; import org.apache.commons.collections.CollectionUtils; -import org.casbin.jcasbin.main.CoreEnforcer; import org.casbin.jcasbin.model.Assertion; import org.casbin.jcasbin.model.Model; import org.casbin.jcasbin.persist.Adapter; @@ -23,6 +25,7 @@ import javax.sql.DataSource; import java.sql.*; +import java.time.Duration; import java.util.*; class CasbinRule { @@ -45,9 +48,11 @@ public String[] toStringArray() { * It can load policy from JDBC supported database or save policy to it. */ abstract class JDBCBaseAdapter implements Adapter { - private DataSource dataSource; + protected static final int _DEFAULT_CONNECTION_TRIES = 3; + protected DataSource dataSource; private final int batchSize = 1000; protected Connection conn; + protected RetryPolicy retryPolicy; /** * JDBCAdapter is the constructor for JDBCAdapter. @@ -73,6 +78,11 @@ protected JDBCBaseAdapter(DataSource dataSource) throws Exception { protected void migrate() throws SQLException { + retryPolicy = RetryPolicy.builder() + .handle(SQLException.class) + .withDelay(Duration.ofSeconds(1)) + .withMaxRetries(_DEFAULT_CONNECTION_TRIES) + .build(); conn = dataSource.getConnection(); Statement stmt = conn.createStatement(); String sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY auto_increment, ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"; @@ -168,7 +178,11 @@ protected void loadPolicyLine(CasbinRule line, Model model) { */ @Override public void loadPolicy(Model model) { - try (Statement stmt = conn.createStatement()) { + Failsafe.with(retryPolicy).run(ctx -> { + if (ctx.isRetry()) { + retry(ctx); + } + Statement stmt = conn.createStatement(); ResultSet rSet = stmt.executeQuery("SELECT * FROM casbin_rule"); ResultSetMetaData rData = rSet.getMetaData(); while (rSet.next()) { @@ -193,10 +207,7 @@ public void loadPolicy(Model model) { loadPolicyLine(line, model); } rSet.close(); - } catch (SQLException e) { - e.printStackTrace(); - throw new Error(e); - } + }); } private CasbinRule savePolicyLine(String ptype, List rule) { @@ -233,7 +244,10 @@ public void savePolicy(Model model) { String cleanSql = "delete from casbin_rule"; String addSql = "INSERT INTO casbin_rule (ptype,v0,v1,v2,v3,v4,v5) VALUES(?,?,?,?,?,?,?)"; - try { + Failsafe.with(retryPolicy).run(ctx -> { + if (ctx.isRetry()) { + retry(ctx); + } conn.setAutoCommit(false); int count = 0; @@ -295,10 +309,7 @@ public void savePolicy(Model model) { } finally { conn.setAutoCommit(true); } - } catch (SQLException e) { - e.printStackTrace(); - throw new Error(e); - } + }); } /** @@ -310,7 +321,11 @@ public void addPolicy(String sec, String ptype, List rule) { String sql = "INSERT INTO casbin_rule (ptype,v0,v1,v2,v3,v4,v5) VALUES(?,?,?,?,?,?,?)"; - try(PreparedStatement ps = conn.prepareStatement(sql)) { + Failsafe.with(retryPolicy).run(ctx -> { + if (ctx.isRetry()) { + retry(ctx); + } + PreparedStatement ps = conn.prepareStatement(sql); CasbinRule line = savePolicyLine(ptype, rule); ps.setString(1, line.ptype); @@ -322,10 +337,7 @@ public void addPolicy(String sec, String ptype, List rule) { ps.setString(7, line.v5); ps.addBatch(); ps.executeBatch(); - } catch (SQLException e) { - e.printStackTrace(); - throw new Error(e); - } + }); } /** @@ -344,26 +356,26 @@ public void removePolicy(String sec, String ptype, List rule) { public void removeFilteredPolicy(String sec, String ptype, int fieldIndex, String... fieldValues) { List values = Optional.of(Arrays.asList(fieldValues)).orElse(new ArrayList<>()); if (CollectionUtils.isEmpty(values)) return; - String sql = "DELETE FROM casbin_rule WHERE ptype = ?"; - int columnIndex = fieldIndex; - for (int i = 0; i < values.size(); i++) { - sql = String.format("%s%s%s%s", sql, " AND v", columnIndex, " = ?"); - columnIndex++; - } - try (PreparedStatement ps = conn.prepareStatement(sql)) { + Failsafe.with(retryPolicy).run(ctx -> { + if (ctx.isRetry()) { + retry(ctx); + } + String sql = "DELETE FROM casbin_rule WHERE ptype = ?"; + int columnIndex = fieldIndex; + for (int i = 0; i < values.size(); i++) { + sql = String.format("%s%s%s%s", sql, " AND v", columnIndex, " = ?"); + columnIndex++; + } + PreparedStatement ps = conn.prepareStatement(sql); ps.setString(1, ptype); for (int j = 0; j < values.size(); j++) { ps.setString(j + 2, values.get(j)); } ps.addBatch(); - ps.executeBatch(); - } catch (SQLException e) { - e.printStackTrace(); - throw new Error(e); - } + }); } /** @@ -372,4 +384,12 @@ public void removeFilteredPolicy(String sec, String ptype, int fieldIndex, Strin public void close() throws SQLException { conn.close(); } + + protected void retry(ExecutionContext ctx) throws SQLException { + if (ctx.getExecutionCount() < _DEFAULT_CONNECTION_TRIES) { + conn = dataSource.getConnection(); + } else { + throw new Error(ctx.getLastFailure()); + } + } }