advsql/db.go

73 lines
1.3 KiB
Go
Raw Normal View History

2022-07-05 12:38:39 +02:00
package advsql
import (
"database/sql"
2022-07-12 21:05:45 +02:00
"fmt"
2022-07-05 12:38:39 +02:00
_ "github.com/go-sql-driver/mysql"
)
// Database pairs a *sql.DB connection pool with a cache of prepared
// statements keyed by their query text. Initialize it with InitDatabase or
// InitMysqlDatabase and release it with Close.
type Database struct {
	// db is the underlying connection pool; nil means "not yet initialized".
	db *sql.DB
	// stmts caches prepared statements by the exact query string.
	stmts map[string]*sql.Stmt
}
func InitDatabase(database *Database, conn *sql.DB) {
panicOnNil(database)
if database.db != nil {
panic("database is initialized already")
}
*database = Database{
db: conn,
stmts: map[string]*sql.Stmt{},
2022-07-08 22:09:53 +02:00
}
database.init()
2022-07-08 22:09:53 +02:00
}
func InitMysqlDatabase(database *Database, host string, port uint16, user, pass, db string) {
panicOnNil(database)
2022-07-05 12:38:39 +02:00
conn, err := sql.Open("mysql", connString(host, port, user, pass, db))
if err != nil {
panic(err)
2022-07-05 12:38:39 +02:00
}
InitDatabase(database, conn)
}
func panicOnNil(database *Database) {
if database == nil {
panic("database is nil. initialize database variable with new()")
}
2022-07-05 12:38:39 +02:00
}
// init prepares and caches, via stmt, every entry of globalStmts whose owner
// is this Database. As a result, a bad query panics during initialization
// rather than on first use.
// NOTE(review): globalStmts is declared elsewhere in the package; presumably
// it is populated by query registrations before Init* runs — confirm.
func (db *Database) init() {
	for _, registered := range globalStmts {
		if registered.db != db {
			continue
		}
		db.stmt(registered.query)
	}
}
func (db *Database) stmt(query string) *sql.Stmt {
if stmt, ok := db.stmts[query]; ok {
return stmt
}
stmt, err := db.db.Prepare(query)
2022-07-05 12:38:39 +02:00
if err != nil {
2022-07-12 21:05:45 +02:00
panic(fmt.Errorf("compilation failed for query '%s' reason: %w", query, err))
2022-07-05 12:38:39 +02:00
}
db.stmts[query] = stmt
return stmt
2022-07-05 12:38:39 +02:00
}
func (db *Database) Close() error {
for _, stmt := range db.stmts {
stmt.Close()
2022-07-05 12:38:39 +02:00
}
return db.db.Close()
}