diff --git a/db.go b/db.go index 9bc5907..5edd74d 100644 --- a/db.go +++ b/db.go @@ -7,37 +7,61 @@ import ( ) type Database struct { - db *sql.DB - closefuncs []func() error + db *sql.DB + stmts map[string]*sql.Stmt } -func NewDatabase(conn *sql.DB) *Database { - return &Database{ - db: conn, - closefuncs: make([]func() error, 0), +func InitDatabase(database *Database, conn *sql.DB) { + panicOnNil(database) + + *database = Database{ + db: conn, + stmts: map[string]*sql.Stmt{}, } + database.init() } -func NewMysqlDatabase(host string, port uint16, user, pass, db string) *Database { +func InitMysqlDatabase(database *Database, host string, port uint16, user, pass, db string) { + panicOnNil(database) + conn, err := sql.Open("mysql", connString(host, port, user, pass, db)) if err != nil { panic(err) } - return NewDatabase(conn) + + InitDatabase(database, conn) } -func (db *Database) prepare(query string) *sql.Stmt { - s, err := db.db.Prepare(query) +func panicOnNil(database *Database) { + if database == nil { + panic("database is nil. initialize database variable with new()") + } +} + +func (db *Database) init() { + for _, globalQuery := range globalStmts { + if db == globalQuery.db { + db.stmt(globalQuery.query) + } + } +} + +func (db *Database) stmt(query string) *sql.Stmt { + if stmt, ok := db.stmts[query]; ok { + return stmt + } + + stmt, err := db.db.Prepare(query) if err != nil { panic(err) } - db.closefuncs = append(db.closefuncs, s.Close) - return s + db.stmts[query] = stmt + return stmt } func (db *Database) Close() error { - for _, close := range db.closefuncs { - close() + for _, stmt := range db.stmts { + stmt.Close() } return db.db.Close() } diff --git a/db_test.go b/db_test.go index 6e9b49a..4ecb870 100644 --- a/db_test.go +++ b/db_test.go @@ -9,33 +9,31 @@ import ( _ "github.com/go-sql-driver/mysql" ) +var TestDatabase *Database //= new(Database) + type User struct { Name string Hash []byte Salt string } -func UserDecoder(u *User, decode ScanFunc) error { - return decode(&u.Name, &u.Hash, &u.Salt) +func ScanUserPkFirst(u *User, encode ScanFunc) error { + return encode(&u.Name, &u.Hash, &u.Salt) } -func InsertUserEncoder(u *User, encode ScanFunc) error { - return encode(u.Name, u.Hash, u.Salt) +func ScanUserPkLast(u *User, encode ScanFunc) error { + return encode(&u.Hash, &u.Salt, &u.Name) } -func UpdateUserByNameEncoder(u *User, encode ScanFunc) error { - return encode(u.Salt, u.Name) -} +var ( + InsertUser = Insert(TestDatabase, "INSERT INTO users VALUES (?, ?, ?)", ScanUserPkFirst) + UpdateUser = Update(TestDatabase, "UPDATE users SET hash = ?, salt = ? WHERE name = ?", ScanUserPkLast) + GetUserByName = QueryOne(TestDatabase, "SELECT * FROM users WHERE namea = ?", ScanUserPkFirst) +) func TestDB(t *testing.T) { - db := NewMysqlDatabase("ip", 3306, "username", "password", "database") - defer db.Close() - - insertUser := Insert(db, "INSERT INTO users VALUES (?, ?, ?)", InsertUserEncoder) - - updateUser := Update(db, "UPDATE users SET salt = ? WHERE name = ?", UpdateUserByNameEncoder) - - getUsers := QueryMany(db, "SELECT * FROM users WHERE name = ?", UserDecoder) + InitMysqlDatabase(TestDatabase, "ip", 3306, "username", "password", "database") + defer TestDatabase.Close() pw := sha512.Sum512([]byte("weiter")) timon := &User{ @@ -43,13 +41,11 @@ func TestDB(t *testing.T) { Hash: pw[:], Salt: "salt", } - fmt.Println("insert:", insertUser(timon)) + fmt.Println("insert:", InsertUser(timon)) timon.Hash = []byte("asd") - fmt.Println("update:", updateUser(timon)) - - for user := range getUsers("tordarus") { - fmt.Printf("name: \"%s\" | hash: \"%s\" | salt: \"%s\"\n", user.Name, hex.EncodeToString(user.Hash), user.Salt) - } + fmt.Println("update:", UpdateUser(timon)) + user := GetUserByName("tordarus") + fmt.Printf("name: \"%s\" | hash: \"%s\" | salt: \"%s\"\n", user.Name, hex.EncodeToString(user.Hash), user.Salt) } diff --git a/global_stmts.go b/global_stmts.go new file mode 100644 index 0000000..ec107c9 --- /dev/null +++ b/global_stmts.go @@ -0,0 +1,12 @@ +package advsql + +type globalStmt struct { + db *Database + query string +} + +var globalStmts = []*globalStmt{} + +func prepareGlobal(db *Database, query string) { + globalStmts = append(globalStmts, &globalStmt{db, query}) +} diff --git a/insert.go b/insert.go index 87ed181..239e4ee 100644 --- a/insert.go +++ b/insert.go @@ -1,9 +1,10 @@ package advsql func Insert[T any](db *Database, query string, encoder func(v *T, encode ScanFunc) error) InsertFunc[T] { - s := db.prepare(query) + prepareGlobal(db, query) return func(value *T) error { + s := db.stmt(query) return encoder(value, func(args ...interface{}) error { _, err := s.Exec(args...) return err diff --git a/query_many.go b/query_many.go index 80ceeae..579c3ee 100644 --- a/query_many.go +++ b/query_many.go @@ -10,13 +10,11 @@ func QueryMany[T any](db *Database, query string, decoder func(v *T, decode Scan } func QueryManyContext[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryManyContextFunc[T] { - s, err := db.db.Prepare(query) - if err != nil { - return nil - } - db.closefuncs = append(db.closefuncs, s.Close) + prepareGlobal(db, query) return func(ctx context.Context, args ...interface{}) <-chan *T { + s := db.stmt(query) + out := make(chan *T, 10) rows, err := s.QueryContext(ctx, args...)