replaced Database constructors with init functions
This commit is contained in:
parent
3eac777080
commit
b2e2f96eb0
48
db.go
48
db.go
@ -8,36 +8,60 @@ import (
|
||||
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
closefuncs []func() error
|
||||
stmts map[string]*sql.Stmt
|
||||
}
|
||||
|
||||
func NewDatabase(conn *sql.DB) *Database {
|
||||
return &Database{
|
||||
func InitDatabase(database *Database, conn *sql.DB) {
|
||||
panicOnNil(database)
|
||||
|
||||
*database = Database{
|
||||
db: conn,
|
||||
closefuncs: make([]func() error, 0),
|
||||
stmts: map[string]*sql.Stmt{},
|
||||
}
|
||||
database.init()
|
||||
}
|
||||
|
||||
func NewMysqlDatabase(host string, port uint16, user, pass, db string) *Database {
|
||||
func InitMysqlDatabase(database *Database, host string, port uint16, user, pass, db string) {
|
||||
panicOnNil(database)
|
||||
|
||||
conn, err := sql.Open("mysql", connString(host, port, user, pass, db))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return NewDatabase(conn)
|
||||
|
||||
InitDatabase(database, conn)
|
||||
}
|
||||
|
||||
func (db *Database) prepare(query string) *sql.Stmt {
|
||||
s, err := db.db.Prepare(query)
|
||||
func panicOnNil(database *Database) {
|
||||
if database == nil {
|
||||
panic("database is nil. initialize database variable with new()")
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) init() {
|
||||
for _, globalQuery := range globalStmts {
|
||||
if db == globalQuery.db {
|
||||
db.stmt(globalQuery.query)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) stmt(query string) *sql.Stmt {
|
||||
if stmt, ok := db.stmts[query]; ok {
|
||||
return stmt
|
||||
}
|
||||
|
||||
stmt, err := db.db.Prepare(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db.closefuncs = append(db.closefuncs, s.Close)
|
||||
return s
|
||||
db.stmts[query] = stmt
|
||||
return stmt
|
||||
}
|
||||
|
||||
func (db *Database) Close() error {
|
||||
for _, close := range db.closefuncs {
|
||||
close()
|
||||
for _, stmt := range db.stmts {
|
||||
stmt.Close()
|
||||
}
|
||||
return db.db.Close()
|
||||
}
|
||||
|
36
db_test.go
36
db_test.go
@ -9,33 +9,31 @@ import (
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
var TestDatabase *Database //= new(Database)
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Hash []byte
|
||||
Salt string
|
||||
}
|
||||
|
||||
func UserDecoder(u *User, decode ScanFunc) error {
|
||||
return decode(&u.Name, &u.Hash, &u.Salt)
|
||||
func ScanUserPkFirst(u *User, encode ScanFunc) error {
|
||||
return encode(&u.Name, &u.Hash, &u.Salt)
|
||||
}
|
||||
|
||||
func InsertUserEncoder(u *User, encode ScanFunc) error {
|
||||
return encode(u.Name, u.Hash, u.Salt)
|
||||
func ScanUserPkLast(u *User, encode ScanFunc) error {
|
||||
return encode(&u.Hash, &u.Salt, &u.Name)
|
||||
}
|
||||
|
||||
func UpdateUserByNameEncoder(u *User, encode ScanFunc) error {
|
||||
return encode(u.Salt, u.Name)
|
||||
}
|
||||
var (
|
||||
InsertUser = Insert(TestDatabase, "INSERT INTO users VALUES (?, ?, ?)", ScanUserPkFirst)
|
||||
UpdateUser = Update(TestDatabase, "UPDATE users SET hash = ?, salt = ? WHERE name = ?", ScanUserPkLast)
|
||||
GetUserByName = QueryOne(TestDatabase, "SELECT * FROM users WHERE namea = ?", ScanUserPkFirst)
|
||||
)
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
db := NewMysqlDatabase("ip", 3306, "username", "password", "database")
|
||||
defer db.Close()
|
||||
|
||||
insertUser := Insert(db, "INSERT INTO users VALUES (?, ?, ?)", InsertUserEncoder)
|
||||
|
||||
updateUser := Update(db, "UPDATE users SET salt = ? WHERE name = ?", UpdateUserByNameEncoder)
|
||||
|
||||
getUsers := QueryMany(db, "SELECT * FROM users WHERE name = ?", UserDecoder)
|
||||
InitMysqlDatabase(TestDatabase, "ip", 3306, "username", "password", "database")
|
||||
defer TestDatabase.Close()
|
||||
|
||||
pw := sha512.Sum512([]byte("weiter"))
|
||||
timon := &User{
|
||||
@ -43,13 +41,11 @@ func TestDB(t *testing.T) {
|
||||
Hash: pw[:],
|
||||
Salt: "salt",
|
||||
}
|
||||
fmt.Println("insert:", insertUser(timon))
|
||||
fmt.Println("insert:", InsertUser(timon))
|
||||
|
||||
timon.Hash = []byte("asd")
|
||||
fmt.Println("update:", updateUser(timon))
|
||||
fmt.Println("update:", UpdateUser(timon))
|
||||
|
||||
for user := range getUsers("tordarus") {
|
||||
user := GetUserByName("tordarus")
|
||||
fmt.Printf("name: \"%s\" | hash: \"%s\" | salt: \"%s\"\n", user.Name, hex.EncodeToString(user.Hash), user.Salt)
|
||||
}
|
||||
|
||||
}
|
||||
|
12
global_stmts.go
Normal file
12
global_stmts.go
Normal file
@ -0,0 +1,12 @@
|
||||
package advsql
|
||||
|
||||
type globalStmt struct {
|
||||
db *Database
|
||||
query string
|
||||
}
|
||||
|
||||
var globalStmts = []*globalStmt{}
|
||||
|
||||
func prepareGlobal(db *Database, query string) {
|
||||
globalStmts = append(globalStmts, &globalStmt{db, query})
|
||||
}
|
@ -1,9 +1,10 @@
|
||||
package advsql
|
||||
|
||||
func Insert[T any](db *Database, query string, encoder func(v *T, encode ScanFunc) error) InsertFunc[T] {
|
||||
s := db.prepare(query)
|
||||
prepareGlobal(db, query)
|
||||
|
||||
return func(value *T) error {
|
||||
s := db.stmt(query)
|
||||
return encoder(value, func(args ...interface{}) error {
|
||||
_, err := s.Exec(args...)
|
||||
return err
|
||||
|
@ -10,13 +10,11 @@ func QueryMany[T any](db *Database, query string, decoder func(v *T, decode Scan
|
||||
}
|
||||
|
||||
func QueryManyContext[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryManyContextFunc[T] {
|
||||
s, err := db.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
db.closefuncs = append(db.closefuncs, s.Close)
|
||||
prepareGlobal(db, query)
|
||||
|
||||
return func(ctx context.Context, args ...interface{}) <-chan *T {
|
||||
s := db.stmt(query)
|
||||
|
||||
out := make(chan *T, 10)
|
||||
|
||||
rows, err := s.QueryContext(ctx, args...)
|
||||
|
Loading…
Reference in New Issue
Block a user