Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
psych0d0g committed Dec 17, 2023
1 parent 6f5ea80 commit 0007771
Showing 1 changed file with 77 additions and 74 deletions.
151 changes: 77 additions & 74 deletions verify_pw.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,15 @@ func main() {
}

// Read database configuration from environment variables
dbHost := os.Getenv("PAPERLESS_DBHOST")
dbPort := os.Getenv("PAPERLESS_DBPORT")
dbName := os.Getenv("PAPERLESS_DBNAME")
dbUser := os.Getenv("DB_USER")
dbPassword := os.Getenv("PAPERLESS_DBPASS")
dbEngine := os.Getenv("PAPERLESS_DBENGINE")

dbHost, dbPort, dbName, dbUser, dbPassword, dbEngine := getDatabaseConfig()
if dbHost == "" || dbPort == "" || dbName == "" || dbUser == "" || dbPassword == "" || dbEngine == "" {
debugPrint("Database configuration is incomplete.")
fmt.Print(failedAuthFatalError)
os.Exit(1)
}

var connStr string
var db *sql.DB
var err error

switch dbEngine {
case "postgres":
connStr = fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
dbHost, dbPort, dbUser, dbPassword, dbName)
db, err = sql.Open("postgres", connStr)
case "mysql", "mariadb":
connStr = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", dbUser, dbPassword, dbHost, dbPort, dbName)
db, err = sql.Open("mysql", connStr)
default:
debugPrint(fmt.Sprintf("Unsupported database engine: %s", dbEngine))
fmt.Print(failedAuthFatalError)
os.Exit(1)
}

// Open database connection
db, err := openDatabaseConnection(dbEngine, dbHost, dbPort, dbUser, dbPassword, dbName)
if err != nil {
debugPrint(fmt.Sprintf("Failed to open database connection: %v", err))
fmt.Print(failedAuthFatalError)
Expand All @@ -86,27 +64,12 @@ func main() {
defer db.Close()

// Query the database for the stored credentials
var query string
var queryArgs []interface{}

switch dbEngine {
case "postgres":
query = "SELECT username, password FROM auth_user WHERE username = $1"
queryArgs = []interface{}{username}
case "mysql", "mariadb":
query = "SELECT username, password FROM auth_user WHERE username = ?"
queryArgs = []interface{}{username}
default:
debugPrint(fmt.Sprintf("Unsupported database engine: %s", dbEngine))
fmt.Print(failedAuthFatalError)
os.Exit(1) // Exit with code 1 for fatal error
}

query, queryArgs := getDatabaseQuery(dbEngine, username)
rows, err := db.Query(query, queryArgs...)
if err != nil {
debugPrint(fmt.Sprintf("Failed to execute database query: %v", err))
fmt.Print(failedAuthFatalError)
os.Exit(1) // Exit with code 1 for fatal error
os.Exit(1)
}
defer rows.Close()

Expand All @@ -115,9 +78,7 @@ func main() {
var dbUsername, dbPassword string
err := rows.Scan(&dbUsername, &dbPassword)
if err != nil {
debugPrint(fmt.Sprintf("Failed to scan database row: %v", err))
fmt.Print(failedAuthFatalError)
os.Exit(1)
handleError("Failed to scan database row", failedAuthFatalError)
}

// Extract parameters from the stored hash
Expand All @@ -126,9 +87,7 @@ func main() {
// Extract iterations from the stored hash
iterations, err := strconv.Atoi(parts[1])
if err != nil {
debugPrint(fmt.Sprintf("Failed to convert iterations to integer: %v", err))
fmt.Print(failedAuthFatalError)
os.Exit(1)
handleError("Failed to convert iterations to integer", failedAuthFatalError)
}

salt := parts[2]
Expand All @@ -143,44 +102,88 @@ func main() {
// Compare the generated hash with the stored hash
if encodedHashedInput == hashedPassword {
// Get the current user's UID and GID
currentUser, err := user.Current()
if err != nil {
debugPrint(fmt.Sprintf("Failed to get current user information: %v", err))
fmt.Print(failedAuthFatalError)
os.Exit(1)
}

uid, err := strconv.Atoi(currentUser.Uid)
if err != nil {
debugPrint(fmt.Sprintf("Failed to convert UID to integer: %v", err))
fmt.Print(failedAuthFatalError)
os.Exit(1)
}

gid, err := strconv.Atoi(currentUser.Gid)
if err != nil {
debugPrint(fmt.Sprintf("Failed to convert GID to integer: %v", err))
fmt.Print(failedAuthFatalError)
os.Exit(1)
}
uid, gid := getCurrentUserIDs()

// Get the consumption directory from the environment variable
consumptionDir := os.Getenv("PAPERLESS_CONSUMPTION_DIR")
if consumptionDir == "" {
debugPrint("PAPERLESS_CONSUMPTION_DIR is not set.")
fmt.Print(failedAuthFatalError)
os.Exit(1)
handleError("PAPERLESS_CONSUMPTION_DIR is not set", failedAuthFatalError)
}

// Print the successful authentication response
fmt.Printf(successfulAuthResponse, uid, gid, consumptionDir)
os.Exit(0)
} else {
debugPrint("Password verification failed.")
os.Exit(1)
handleError("Password verification failed", failedAuthFatalError)
}
} else {
debugPrint("User not found in the database.")
handleError("User not found in the database", failedAuthFatalError)
}
}

func getDatabaseConfig() (string, string, string, string, string, string) {
return os.Getenv("PAPERLESS_DBHOST"), os.Getenv("PAPERLESS_DBPORT"),
os.Getenv("PAPERLESS_DBNAME"), os.Getenv("DB_USER"),
os.Getenv("PAPERLESS_DBPASS"), os.Getenv("PAPERLESS_DBENGINE")
}

func openDatabaseConnection(engine, host, port, user, password, name string) (*sql.DB, error) {
var connStr string
switch engine {
case "postgres":
connStr = fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", host, port, user, password, name)
case "mysql", "mariadb":
connStr = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", user, password, host, port, name)
default:
debugPrint(fmt.Sprintf("Unsupported database engine: %s", engine))
fmt.Print(failedAuthFatalError)
os.Exit(1)
}

return sql.Open(engine, connStr)
}

func getDatabaseQuery(engine, username string) (string, []interface{}) {
var query string
var queryArgs []interface{}

switch engine {
case "postgres":
query = "SELECT username, password FROM auth_user WHERE username = $1"
queryArgs = []interface{}{username}
case "mysql", "mariadb":
query = "SELECT username, password FROM auth_user WHERE username = ?"
queryArgs = []interface{}{username}
default:
debugPrint(fmt.Sprintf("Unsupported database engine: %s", engine))
fmt.Print(failedAuthFatalError)
os.Exit(1)
}

return query, queryArgs
}

func getCurrentUserIDs() (int, int) {
currentUser, err := user.Current()
if err != nil {
handleError("Failed to get current user information", failedAuthFatalError)
}

uid, err := strconv.Atoi(currentUser.Uid)
if err != nil {
handleError("Failed to convert UID to integer", failedAuthFatalError)
}

gid, err := strconv.Atoi(currentUser.Gid)
if err != nil {
handleError("Failed to convert GID to integer", failedAuthFatalError)
}

return uid, gid
}

func handleError(errorMessage, exitMessage string) {
debugPrint(errorMessage)
fmt.Print(exitMessage)
os.Exit(1)
}

0 comments on commit 0007771

Please sign in to comment.