Compare commits
11 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
31258a50a4 | ||
![]() |
2ce3d755bd | ||
![]() |
f7bae81df7 | ||
![]() |
3f8e334dbb | ||
![]() |
6d24df4153 | ||
![]() |
09c2d07e54 | ||
![]() |
e39d4c3a13 | ||
![]() |
96484eba5a | ||
![]() |
b2e2f96eb0 | ||
![]() |
3eac777080 | ||
![]() |
e9402c2f4e |
61
db.go
61
db.go
@ -2,42 +2,71 @@ package advsql
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
closefuncs []func() error
|
stmts map[string]*sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabase(conn *sql.DB) *Database {
|
func InitDatabase(database *Database, conn *sql.DB) {
|
||||||
return &Database{
|
panicOnNil(database)
|
||||||
db: conn,
|
|
||||||
closefuncs: make([]func() error, 0),
|
if database.db != nil {
|
||||||
|
panic("database is initialized already")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*database = Database{
|
||||||
|
db: conn,
|
||||||
|
stmts: map[string]*sql.Stmt{},
|
||||||
|
}
|
||||||
|
database.init()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMysqlDatabase(host string, port uint16, user, pass, db string) *Database {
|
func InitMysqlDatabase(database *Database, host string, port uint16, user, pass, db string) {
|
||||||
|
panicOnNil(database)
|
||||||
|
|
||||||
conn, err := sql.Open("mysql", connString(host, port, user, pass, db))
|
conn, err := sql.Open("mysql", connString(host, port, user, pass, db))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return NewDatabase(conn)
|
|
||||||
|
InitDatabase(database, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Database) prepare(query string) *sql.Stmt {
|
func panicOnNil(database *Database) {
|
||||||
s, err := db.db.Prepare(query)
|
if database == nil {
|
||||||
if err != nil {
|
panic("database is nil. initialize database variable with new()")
|
||||||
panic(err)
|
|
||||||
}
|
}
|
||||||
db.closefuncs = append(db.closefuncs, s.Close)
|
}
|
||||||
return s
|
|
||||||
|
func (db *Database) init() {
|
||||||
|
for _, globalQuery := range globalStmts {
|
||||||
|
if db == globalQuery.db {
|
||||||
|
db.stmt(globalQuery.query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *Database) stmt(query string) *sql.Stmt {
|
||||||
|
if stmt, ok := db.stmts[query]; ok {
|
||||||
|
return stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := db.db.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("compilation failed for query '%s' reason: %w", query, err))
|
||||||
|
}
|
||||||
|
db.stmts[query] = stmt
|
||||||
|
return stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Database) Close() error {
|
func (db *Database) Close() error {
|
||||||
for _, close := range db.closefuncs {
|
for _, stmt := range db.stmts {
|
||||||
close()
|
stmt.Close()
|
||||||
}
|
}
|
||||||
return db.db.Close()
|
return db.db.Close()
|
||||||
}
|
}
|
||||||
|
38
db_test.go
38
db_test.go
@ -9,33 +9,31 @@ import (
|
|||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var TestDatabase = new(Database)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Name string
|
Name string
|
||||||
Hash []byte
|
Hash []byte
|
||||||
Salt string
|
Salt string
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserDecoder(u *User, decode DecodeFunc) error {
|
func ScanUserPkFirst(u *User, encode ScanFunc) error {
|
||||||
return decode(&u.Name, &u.Hash, &u.Salt)
|
return encode(&u.Name, &u.Hash, &u.Salt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertUserEncoder(u *User, encode EncodeFunc) error {
|
func ScanUserPkLast(u *User, encode ScanFunc) error {
|
||||||
return encode(u.Name, u.Hash, u.Salt)
|
return encode(&u.Hash, &u.Salt, &u.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserByNameEncoder(u *User, encode EncodeFunc) error {
|
var (
|
||||||
return encode(u.Salt, u.Name)
|
InsertUser = Insert(TestDatabase, "INSERT INTO users VALUES (?, ?, ?)", ScanUserPkFirst)
|
||||||
}
|
UpdateUser = Update(TestDatabase, "UPDATE users SET hash = ?, salt = ? WHERE name = ?", ScanUserPkLast)
|
||||||
|
GetUserByName = QueryOne(TestDatabase, "SELECT * FROM users WHERE namea = ?", ScanUserPkFirst)
|
||||||
|
)
|
||||||
|
|
||||||
func TestDB(t *testing.T) {
|
func TestDB(t *testing.T) {
|
||||||
db := NewMysqlDatabase("ip", 3306, "username", "password", "database")
|
InitMysqlDatabase(TestDatabase, "ip", 3306, "username", "password", "database")
|
||||||
defer db.Close()
|
defer TestDatabase.Close()
|
||||||
|
|
||||||
insertUser := Insert(db, "INSERT INTO users VALUES (?, ?, ?)", InsertUserEncoder)
|
|
||||||
|
|
||||||
updateUser := Update(db, "UPDATE users SET salt = ? WHERE name = ?", UpdateUserByNameEncoder)
|
|
||||||
|
|
||||||
getUsers := QueryMany(db, "SELECT * FROM users WHERE name = ?", UserDecoder)
|
|
||||||
|
|
||||||
pw := sha512.Sum512([]byte("weiter"))
|
pw := sha512.Sum512([]byte("weiter"))
|
||||||
timon := &User{
|
timon := &User{
|
||||||
@ -43,13 +41,11 @@ func TestDB(t *testing.T) {
|
|||||||
Hash: pw[:],
|
Hash: pw[:],
|
||||||
Salt: "salt",
|
Salt: "salt",
|
||||||
}
|
}
|
||||||
fmt.Println("insert:", insertUser(timon))
|
fmt.Println("insert:", InsertUser(timon))
|
||||||
|
|
||||||
timon.Hash = []byte("asd")
|
timon.Hash = []byte("asd")
|
||||||
fmt.Println("update:", updateUser(timon))
|
fmt.Println("update:", UpdateUser(timon))
|
||||||
|
|
||||||
for user := range getUsers("tordarus") {
|
|
||||||
fmt.Printf("name: \"%s\" | hash: \"%s\" | salt: \"%s\"\n", user.Name, hex.EncodeToString(user.Hash), user.Salt)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
user := GetUserByName("tordarus")
|
||||||
|
fmt.Printf("name: \"%s\" | hash: \"%s\" | salt: \"%s\"\n", user.Name, hex.EncodeToString(user.Hash), user.Salt)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,6 @@ type defaultDecoderType interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Decoder provides default decoders for primitive datatypes
|
// Decoder provides default decoders for primitive datatypes
|
||||||
func Decoder[T defaultDecoderType](value *T, decode DecodeFunc) error {
|
func Decoder[T defaultDecoderType](value *T, decode ScanFunc) error {
|
||||||
return decode(value)
|
return decode(value)
|
||||||
}
|
}
|
||||||
|
11
delete.go
Normal file
11
delete.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package advsql
|
||||||
|
|
||||||
|
func Delete(db *Database, query string) DeleteFunc {
|
||||||
|
prepareGlobal(db, query)
|
||||||
|
|
||||||
|
return func(args ...interface{}) error {
|
||||||
|
s := db.stmt(query)
|
||||||
|
_, err := s.Exec(args...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
@ -10,3 +10,7 @@ type QueryOneContextFunc[T any] func(ctx context.Context, args ...interface{}) *
|
|||||||
|
|
||||||
type InsertFunc[T any] func(v *T) error
|
type InsertFunc[T any] func(v *T) error
|
||||||
type UpdateFunc[T any] func(v *T) error
|
type UpdateFunc[T any] func(v *T) error
|
||||||
|
|
||||||
|
type DeleteFunc func(args ...interface{}) error
|
||||||
|
|
||||||
|
type ScanFunc = func(args ...interface{}) error
|
||||||
|
12
global_stmts.go
Normal file
12
global_stmts.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package advsql
|
||||||
|
|
||||||
|
type globalStmt struct {
|
||||||
|
db *Database
|
||||||
|
query string
|
||||||
|
}
|
||||||
|
|
||||||
|
var globalStmts = []*globalStmt{}
|
||||||
|
|
||||||
|
func prepareGlobal(db *Database, query string) {
|
||||||
|
globalStmts = append(globalStmts, &globalStmt{db, query})
|
||||||
|
}
|
7
go.mod
7
go.mod
@ -1,8 +1,5 @@
|
|||||||
module git.tordarus.net/Tordarus/advsql
|
module git.milar.in/milarin/advsql
|
||||||
|
|
||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
require (
|
require github.com/go-sql-driver/mysql v1.6.0
|
||||||
git.tordarus.net/Tordarus/adverr v0.2.0
|
|
||||||
github.com/go-sql-driver/mysql v1.6.0
|
|
||||||
)
|
|
||||||
|
2
go.sum
2
go.sum
@ -1,4 +1,2 @@
|
|||||||
git.tordarus.net/Tordarus/adverr v0.2.0 h1:kLYjR2/Vb2GHiSAMvAv+WPNaHR9BRphKanf8H/pCZdA=
|
|
||||||
git.tordarus.net/Tordarus/adverr v0.2.0/go.mod h1:XRf0+7nhOkIEr0gi9DUG4RvV2KaOFB0fYPDaR1KLenw=
|
|
||||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
package advsql
|
package advsql
|
||||||
|
|
||||||
func Insert[T any](db *Database, query string, encoder func(v *T, encode EncodeFunc) error) InsertFunc[T] {
|
func Insert[T any](db *Database, query string, encoder func(v *T, encode ScanFunc) error) InsertFunc[T] {
|
||||||
s := db.prepare(query)
|
prepareGlobal(db, query)
|
||||||
|
|
||||||
return func(value *T) error {
|
return func(value *T) error {
|
||||||
|
s := db.stmt(query)
|
||||||
return encoder(value, func(args ...interface{}) error {
|
return encoder(value, func(args ...interface{}) error {
|
||||||
_, err := s.Exec(args...)
|
_, err := s.Exec(args...)
|
||||||
return err
|
return err
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
# advsql
|
|
||||||
|
|
||||||
A more advanced ORM for the Golang sql package
|
|
@ -1,15 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
func (s *Store) checkTableIntegrity(value interface{}) (bool, error) {
|
|
||||||
valueStructure, err := objectStructure(value)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tableStructure, err := s.tableStructure(valueStructure.TypeName)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return compareStructures(valueStructure, tableStructure), nil
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCheckTableIntegrity(t *testing.T) {
|
|
||||||
store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("could not make store", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Person struct {
|
|
||||||
Surname string
|
|
||||||
Lastname string
|
|
||||||
Age int
|
|
||||||
}
|
|
||||||
|
|
||||||
p := &Person{
|
|
||||||
Surname: "Timon",
|
|
||||||
Lastname: "Ringwald",
|
|
||||||
Age: 25,
|
|
||||||
}
|
|
||||||
|
|
||||||
integrity, err := store.checkTableIntegrity(p)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("could not check table integrity", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println("integrity:", integrity)
|
|
||||||
}
|
|
@ -1,96 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *Store) createTable(structure *structure) error {
|
|
||||||
b := new(strings.Builder)
|
|
||||||
b.WriteString("CREATE TABLE ")
|
|
||||||
b.WriteString(structure.TypeName)
|
|
||||||
b.WriteString("(")
|
|
||||||
|
|
||||||
fieldAmount := len(structure.Fields)
|
|
||||||
|
|
||||||
for i, field := range structure.Fields {
|
|
||||||
b.WriteString(field.FieldName)
|
|
||||||
b.WriteString(" ")
|
|
||||||
b.WriteString(getTypeSpecForDB(field.FieldType))
|
|
||||||
|
|
||||||
if i < fieldAmount-1 {
|
|
||||||
b.WriteString(",")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteString(")")
|
|
||||||
|
|
||||||
log.Println("sql command: " + b.String())
|
|
||||||
|
|
||||||
_, err := s.db.Exec(b.String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTypeSpecForDB(t reflect.Type) string {
|
|
||||||
noPtrType := t
|
|
||||||
dereferenced := false
|
|
||||||
|
|
||||||
// derefence all pointers recursively
|
|
||||||
for noPtrType.Kind() == reflect.Ptr {
|
|
||||||
noPtrType = t.Elem()
|
|
||||||
dereferenced = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// actual data types
|
|
||||||
b := new(strings.Builder)
|
|
||||||
switch noPtrType.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
b.WriteString("INT")
|
|
||||||
|
|
||||||
case reflect.Int:
|
|
||||||
b.WriteString("BIGINT")
|
|
||||||
case reflect.Int8:
|
|
||||||
b.WriteString("TINYINT")
|
|
||||||
case reflect.Int16:
|
|
||||||
b.WriteString("SMALLINT")
|
|
||||||
case reflect.Int32:
|
|
||||||
b.WriteString("INT")
|
|
||||||
case reflect.Int64:
|
|
||||||
b.WriteString("BIGINT")
|
|
||||||
|
|
||||||
case reflect.Uint:
|
|
||||||
b.WriteString("BIGINT UNSIGNED")
|
|
||||||
case reflect.Uint8:
|
|
||||||
b.WriteString("TINYINT UNSIGNED")
|
|
||||||
case reflect.Uint16:
|
|
||||||
b.WriteString("SMALLINT UNSIGNED")
|
|
||||||
case reflect.Uint32:
|
|
||||||
b.WriteString("INT UNSIGNED")
|
|
||||||
case reflect.Uint64:
|
|
||||||
b.WriteString("BIGINT UNSIGNED")
|
|
||||||
|
|
||||||
case reflect.Float32:
|
|
||||||
b.WriteString("FLOAT")
|
|
||||||
case reflect.Float64:
|
|
||||||
b.WriteString("DOUBLE")
|
|
||||||
|
|
||||||
case reflect.String:
|
|
||||||
b.WriteString("TEXT")
|
|
||||||
|
|
||||||
case reflect.Ptr:
|
|
||||||
panic("something's broken. All pointers should have been dereferenced by now")
|
|
||||||
default:
|
|
||||||
panic("type spec for db not implemented yet: " + noPtrType.Kind().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !dereferenced {
|
|
||||||
b.WriteString(" NOT NULL")
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCreateTable(t *testing.T) {
|
|
||||||
store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("could not make store", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Person struct {
|
|
||||||
Surname string
|
|
||||||
Lastname string
|
|
||||||
Age int
|
|
||||||
}
|
|
||||||
|
|
||||||
p := &Person{
|
|
||||||
Surname: "Timon",
|
|
||||||
Lastname: "Ringwald",
|
|
||||||
Age: 25,
|
|
||||||
}
|
|
||||||
|
|
||||||
structure, err := objectStructure(p)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Could not get structure of value", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = store.createTable(structure)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Could not create table", err)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,11 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
func (s *Store) Insert(value interface{}) error {
|
|
||||||
// structure, err := objectStructure(value)
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return nil
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,27 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestInsert(t *testing.T) {
|
|
||||||
store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("could not make store", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Person struct {
|
|
||||||
Surname string
|
|
||||||
Lastname string
|
|
||||||
Age int
|
|
||||||
}
|
|
||||||
|
|
||||||
p := &Person{
|
|
||||||
Surname: "Timon",
|
|
||||||
Lastname: "Ringwald",
|
|
||||||
Age: 25,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = store.Insert(p)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Could not create table", err)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,35 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrNoStructure = errors.New("value is not of type structure")
|
|
||||||
|
|
||||||
// analyze returns the structure of value
|
|
||||||
func objectStructure(value interface{}) (*structure, error) {
|
|
||||||
t := reflect.TypeOf(value)
|
|
||||||
k := t.Kind()
|
|
||||||
|
|
||||||
// only structs are allowed to be stored in DB
|
|
||||||
// if pointer: dereference and check again via recursion
|
|
||||||
if k == reflect.Ptr {
|
|
||||||
v := reflect.ValueOf(value)
|
|
||||||
return objectStructure(reflect.Indirect(v).Interface())
|
|
||||||
} else if k != reflect.Struct {
|
|
||||||
return nil, ErrNoStructure
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldAmount := t.NumField()
|
|
||||||
structure := new(structure)
|
|
||||||
structure.TypeName = t.Name()
|
|
||||||
structure.Fields = make([]field, fieldAmount)
|
|
||||||
|
|
||||||
for fieldIndex := 0; fieldIndex < fieldAmount; fieldIndex++ {
|
|
||||||
reflectField := t.Field(fieldIndex)
|
|
||||||
structure.Fields[fieldIndex] = makeFieldUsingReflect(reflectField)
|
|
||||||
}
|
|
||||||
|
|
||||||
return structure, nil
|
|
||||||
}
|
|
47
old/store.go
47
old/store.go
@ -1,47 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Store struct {
|
|
||||||
db *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStoreMySQL(host string, port uint16, user, pass, db string) (*Store, error) {
|
|
||||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, pass, host, port, db)
|
|
||||||
|
|
||||||
conn, err := sql.Open("mysql", dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewStore(conn), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStore(db *sql.DB) *Store {
|
|
||||||
return &Store{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) Close() error {
|
|
||||||
stmts := []*sql.Stmt{}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
for _, stmt := range stmts {
|
|
||||||
err = stmt.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.db.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,29 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
func (s *Store) tableStructure(tableName string) (*structure, error) {
|
|
||||||
rows, err := s.db.Query("SELECT * FROM " + tableName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
types, err := rows.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
structure := new(structure)
|
|
||||||
structure.TypeName = tableName
|
|
||||||
|
|
||||||
columnAmount := len(types)
|
|
||||||
structure.Fields = make([]field, columnAmount)
|
|
||||||
|
|
||||||
for columnIndex, columType := range types {
|
|
||||||
structure.Fields[columnIndex] = field{
|
|
||||||
FieldName: columType.Name(),
|
|
||||||
FieldType: columType.ScanType(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return structure, nil
|
|
||||||
}
|
|
94
old/types.go
94
old/types.go
@ -1,94 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
type structure struct {
|
|
||||||
TypeName string
|
|
||||||
Fields []field
|
|
||||||
}
|
|
||||||
|
|
||||||
type field struct {
|
|
||||||
FieldName string
|
|
||||||
FieldType reflect.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeFieldUsingReflect(reflectField reflect.StructField) field {
|
|
||||||
return field{
|
|
||||||
FieldName: reflectField.Name,
|
|
||||||
FieldType: reflectField.Type,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareStructures(first, second *structure) bool {
|
|
||||||
// compare names
|
|
||||||
if first.TypeName != second.TypeName {
|
|
||||||
log.Println("different names")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
firstFieldAmount, secondFieldAmount := len(first.Fields), len(second.Fields)
|
|
||||||
|
|
||||||
// compare field amounts
|
|
||||||
if firstFieldAmount != secondFieldAmount {
|
|
||||||
log.Println("different field amounts")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// compare fields
|
|
||||||
for i := 0; i < firstFieldAmount; i++ {
|
|
||||||
firstField, secondField := first.Fields[i], second.Fields[i]
|
|
||||||
if !compareFields(firstField, secondField) {
|
|
||||||
log.Println("different field:", firstField, secondField)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareFields(first, second field) bool {
|
|
||||||
t1, t2 := first.FieldType, second.FieldType
|
|
||||||
|
|
||||||
// check if field names are equal
|
|
||||||
fieldNamesEqual := first.FieldName == second.FieldName
|
|
||||||
|
|
||||||
// exceptions for default comparison
|
|
||||||
switch true {
|
|
||||||
case checkBoth(first, second, isString, isByteSlice):
|
|
||||||
fallthrough
|
|
||||||
case isInt(first) && isInt(second):
|
|
||||||
fallthrough
|
|
||||||
case isUint(first) && isUint(second):
|
|
||||||
return fieldNamesEqual
|
|
||||||
|
|
||||||
default:
|
|
||||||
return t1 == t2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TypeCheckFunc represents any function for type checking between Go and DB
|
|
||||||
type TypeCheckFunc func(field) bool
|
|
||||||
|
|
||||||
// checks if both check functions apply to either of the fields
|
|
||||||
func checkBoth(first, second field, firstCheck, secondCheck TypeCheckFunc) bool {
|
|
||||||
return firstCheck(first) && secondCheck(second) || firstCheck(second) && secondCheck(first)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isString(field field) bool {
|
|
||||||
return field.FieldType.Kind() == reflect.String
|
|
||||||
}
|
|
||||||
|
|
||||||
func isInt(field field) bool {
|
|
||||||
return field.FieldType.Kind() == reflect.Int || field.FieldType.Kind() == reflect.Int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func isUint(field field) bool {
|
|
||||||
return field.FieldType.Kind() == reflect.Uint || field.FieldType.Kind() == reflect.Uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
func isByteSlice(field field) bool {
|
|
||||||
return field.FieldType.Kind() == reflect.Slice && field.FieldType.Elem().Kind() == reflect.Uint8
|
|
||||||
}
|
|
@ -2,21 +2,19 @@ package advsql
|
|||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
func QueryMany[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryManyFunc[T] {
|
func QueryMany[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryManyFunc[T] {
|
||||||
ctxfunc := QueryManyContext(db, query, decoder)
|
ctxfunc := QueryManyContext(db, query, decoder)
|
||||||
return func(args ...interface{}) <-chan *T {
|
return func(args ...interface{}) <-chan *T {
|
||||||
return ctxfunc(context.Background(), args...)
|
return ctxfunc(context.Background(), args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryManyContext[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryManyContextFunc[T] {
|
func QueryManyContext[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryManyContextFunc[T] {
|
||||||
s, err := db.db.Prepare(query)
|
prepareGlobal(db, query)
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
db.closefuncs = append(db.closefuncs, s.Close)
|
|
||||||
|
|
||||||
return func(ctx context.Context, args ...interface{}) <-chan *T {
|
return func(ctx context.Context, args ...interface{}) <-chan *T {
|
||||||
|
s := db.stmt(query)
|
||||||
|
|
||||||
out := make(chan *T, 10)
|
out := make(chan *T, 10)
|
||||||
|
|
||||||
rows, err := s.QueryContext(ctx, args...)
|
rows, err := s.QueryContext(ctx, args...)
|
||||||
|
@ -2,14 +2,14 @@ package advsql
|
|||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
func QueryOne[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryOneFunc[T] {
|
func QueryOne[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryOneFunc[T] {
|
||||||
ctxfunc := QueryOneContext(db, query, decoder)
|
ctxfunc := QueryOneContext(db, query, decoder)
|
||||||
return func(args ...interface{}) *T {
|
return func(args ...interface{}) *T {
|
||||||
return ctxfunc(context.Background(), args...)
|
return ctxfunc(context.Background(), args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryOneContext[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryOneContextFunc[T] {
|
func QueryOneContext[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryOneContextFunc[T] {
|
||||||
manyfunc := QueryManyContext(db, query, decoder)
|
manyfunc := QueryManyContext(db, query, decoder)
|
||||||
return func(ctx context.Context, args ...interface{}) *T {
|
return func(ctx context.Context, args ...interface{}) *T {
|
||||||
nctx, cancel := context.WithCancel(ctx)
|
nctx, cancel := context.WithCancel(ctx)
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
type DecodeFunc func(args ...interface{}) error
|
|
||||||
type EncodeFunc func(args ...interface{}) error
|
|
59
stmt.go
59
stmt.go
@ -1,59 +0,0 @@
|
|||||||
package advsql
|
|
||||||
|
|
||||||
// import (
|
|
||||||
// "database/sql"
|
|
||||||
// )
|
|
||||||
|
|
||||||
// type Stmt[M any] struct {
|
|
||||||
// stmt *sql.Stmt
|
|
||||||
// scan func(Scanner) (*M, error)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func NewStmt[M any](db *Database, query string, scan func(s Scanner) (*M, error)) (*Stmt[M], error) {
|
|
||||||
// s, err := db.db.Prepare(query)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
|
|
||||||
// stmt := Stmt[M]{
|
|
||||||
// stmt: s,
|
|
||||||
// scan: scan,
|
|
||||||
// }
|
|
||||||
// db.closefuncs = append(db.closefuncs, stmt.Close)
|
|
||||||
// return &stmt, err
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func (stmt *Stmt[M]) Close() error {
|
|
||||||
// return stmt.stmt.Close()
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func (stmt *Stmt[M]) Many(args ...interface{}) (<-chan *M, error) {
|
|
||||||
// rows, err := stmt.stmt.Query(args...)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
|
|
||||||
// out := make(chan *M, 10)
|
|
||||||
|
|
||||||
// go func() {
|
|
||||||
// defer rows.Close()
|
|
||||||
// defer close(out)
|
|
||||||
|
|
||||||
// for rows.Next() {
|
|
||||||
// if v, err := stmt.scan(rows); err == nil {
|
|
||||||
// out <- v
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }()
|
|
||||||
|
|
||||||
// return out, nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func (stmt *Stmt[M]) Single(args ...interface{}) (*M, error) {
|
|
||||||
// rows, err := stmt.stmt.Query(args...)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
// defer rows.Close()
|
|
||||||
// return stmt.scan(rows)
|
|
||||||
// }
|
|
@ -1,5 +1,5 @@
|
|||||||
package advsql
|
package advsql
|
||||||
|
|
||||||
func Update[T any](db *Database, query string, encoder func(v *T, encode EncodeFunc) error) UpdateFunc[T] {
|
func Update[T any](db *Database, query string, encoder func(v *T, encode ScanFunc) error) UpdateFunc[T] {
|
||||||
return UpdateFunc[T](Insert(db, query, encoder))
|
return UpdateFunc[T](Insert(db, query, encoder))
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user