68 lines
1.1 KiB
Go
68 lines
1.1 KiB
Go
package advsql
|
|
|
|
import (
|
|
"database/sql"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
type Database struct {
|
|
db *sql.DB
|
|
stmts map[string]*sql.Stmt
|
|
}
|
|
|
|
func InitDatabase(database *Database, conn *sql.DB) {
|
|
panicOnNil(database)
|
|
|
|
*database = Database{
|
|
db: conn,
|
|
stmts: map[string]*sql.Stmt{},
|
|
}
|
|
database.init()
|
|
}
|
|
|
|
func InitMysqlDatabase(database *Database, host string, port uint16, user, pass, db string) {
|
|
panicOnNil(database)
|
|
|
|
conn, err := sql.Open("mysql", connString(host, port, user, pass, db))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
InitDatabase(database, conn)
|
|
}
|
|
|
|
func panicOnNil(database *Database) {
|
|
if database == nil {
|
|
panic("database is nil. initialize database variable with new()")
|
|
}
|
|
}
|
|
|
|
func (db *Database) init() {
|
|
for _, globalQuery := range globalStmts {
|
|
if db == globalQuery.db {
|
|
db.stmt(globalQuery.query)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (db *Database) stmt(query string) *sql.Stmt {
|
|
if stmt, ok := db.stmts[query]; ok {
|
|
return stmt
|
|
}
|
|
|
|
stmt, err := db.db.Prepare(query)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
db.stmts[query] = stmt
|
|
return stmt
|
|
}
|
|
|
|
func (db *Database) Close() error {
|
|
for _, stmt := range db.stmts {
|
|
stmt.Close()
|
|
}
|
|
return db.db.Close()
|
|
}
|