advsql/db.go
2022-07-12 21:05:45 +02:00

69 lines
1.2 KiB
Go

package advsql
import (
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
)
// Database pairs a *sql.DB with a cache of prepared statements keyed by
// their SQL text.
//
// Initialize it with InitDatabase or InitMysqlDatabase. In the zero value,
// both the pool and the statement map are nil, so stmt would panic on it.
//
// NOTE(review): stmts is not guarded by any lock; if stmt can be reached
// from several goroutines at once, cache misses race on the map. Confirm
// that callers serialize access.
type Database struct {
// db is the underlying database handle / connection pool.
db *sql.DB
// stmts caches prepared statements by query string. It is filled lazily by
// stmt and released by Close.
stmts map[string]*sql.Stmt
}
// InitDatabase overwrites the caller-allocated *database so that it uses
// conn and starts with an empty statement cache. It then prepares every
// entry of globalStmts whose db field points at this database.
//
// It panics if database is nil, or if any of those statements fails to
// prepare.
func InitDatabase(database *Database, conn *sql.DB) {
	panicOnNil(database)

	fresh := Database{
		stmts: make(map[string]*sql.Stmt),
		db:    conn,
	}
	*database = fresh

	database.init()
}
// InitMysqlDatabase opens a "mysql" handle from the given connection
// parameters and hands it to InitDatabase.
//
// The nil check runs before any handle is opened. It panics if database is
// nil or if sql.Open reports an error. sql.Open may only validate its
// arguments, so connectivity problems typically surface later, for example
// when InitDatabase prepares the global statements.
func InitMysqlDatabase(database *Database, host string, port uint16, user, pass, db string) {
	panicOnNil(database)

	dsn := connString(host, port, user, pass, db)
	handle, openErr := sql.Open("mysql", dsn)
	if openErr != nil {
		panic(openErr)
	}

	InitDatabase(database, handle)
}
// panicOnNil rejects a nil *Database. The Init* functions fill the
// pointed-to value in place, so the caller must allocate it (e.g. new())
// first.
func panicOnNil(database *Database) {
	if database != nil {
		return
	}
	panic("database is nil. initialize database variable with new()")
}
// init walks globalStmts and prepares (via stmt) each query whose db field
// is this Database. It skips the entries that belong to other databases.
// Because stmt panics on a preparation failure, a broken query panics here,
// during initialization.
func (db *Database) init() {
	for _, registered := range globalStmts {
		if registered.db != db {
			continue
		}
		db.stmt(registered.query)
	}
}
// stmt returns the prepared statement for query. The first time a query is
// requested, it is prepared on db.db and cached in db.stmts.
//
// It panics with a wrapped error if preparation fails. A query that fails
// to prepare is not cached.
//
// NOTE(review): db.stmts is read and written without locking; confirm that
// stmt is never reached concurrently.
func (db *Database) stmt(query string) *sql.Stmt {
	cached, found := db.stmts[query]
	if !found {
		prepared, err := db.db.Prepare(query)
		if err != nil {
			panic(fmt.Errorf("compilation failed for query '%s' reason: %w", query, err))
		}
		db.stmts[query] = prepared
		cached = prepared
	}
	return cached
}
// Close closes every cached prepared statement and then the underlying
// *sql.DB. The statement cache itself is left as is.
//
// Every statement is closed even if some of them fail. If closing the
// *sql.DB fails, that error is returned (wrapped with %w and annotated with
// the first statement error, if any). Otherwise the first statement close
// error encountered is returned, or nil if everything closed cleanly.
func (db *Database) Close() error {
	var stmtErr error
	for query, stmt := range db.stmts {
		// Keep going after a failure so every statement is released; only
		// the first error encountered is kept (map order is unspecified).
		if err := stmt.Close(); err != nil && stmtErr == nil {
			stmtErr = fmt.Errorf("closing prepared statement for query %q: %w", query, err)
		}
	}

	dbErr := db.db.Close()
	switch {
	case stmtErr == nil:
		return dbErr
	case dbErr == nil:
		return stmtErr
	default:
		return fmt.Errorf("%w (also: %v)", dbErr, stmtErr)
	}
}