Compare commits

..

18 Commits
v0.0.1 ... main

Author SHA1 Message Date
Timon Ringwald
31258a50a4 improved Delete func 2022-09-06 10:51:21 +02:00
Timon Ringwald
2ce3d755bd improved Delete 2022-09-06 10:49:22 +02:00
Timon Ringwald
f7bae81df7 added Delete 2022-09-06 10:42:09 +02:00
Timon Ringwald
3f8e334dbb panic when initializing already initialized database 2022-08-04 09:37:01 +02:00
Timon Ringwald
6d24df4153 moved to git.milar.in 2022-08-03 21:45:35 +02:00
Timon Ringwald
09c2d07e54 removed leaked password 2022-07-12 21:06:43 +02:00
Timon Ringwald
e39d4c3a13 show sql query on sql error 2022-07-12 21:05:45 +02:00
Timon Ringwald
96484eba5a removed SQL syntax error 2022-07-12 21:01:49 +02:00
Timon Ringwald
b2e2f96eb0 replaced Database constructors with init functions 2022-07-12 20:56:33 +02:00
Timon Ringwald
3eac777080 EncodeFunc and DecodeFunc replaced by ScanFunc 2022-07-12 17:56:06 +02:00
Timon Ringwald
e9402c2f4e EncodeFunc and DecodeFunc as aliases 2022-07-12 16:32:02 +02:00
Timon Ringwald
cfeb6940fb fixed errors 2022-07-11 13:28:07 +02:00
Timon Ringwald
a1e118f200 removed error return value from constructor 2022-07-11 13:05:16 +02:00
Timon Ringwald
a39f467b93 QueryOne implemented
Query renamed to QueryMany
Default Decoders refactored
2022-07-11 13:00:36 +02:00
Timon Ringwald
de1a975799 parse times from db into time.Time 2022-07-11 12:30:32 +02:00
Timon Ringwald
4346466345 panic on init error 2022-07-09 23:55:22 +02:00
Timon Ringwald
5f85c26d4a added default decoders 2022-07-08 22:54:33 +02:00
Timon Ringwald
b4c0849e29 renamed ExecFunc to EncodeFunc and ScanFunc to DecodeFunc 2022-07-08 22:25:05 +02:00
27 changed files with 172 additions and 609 deletions

62
db.go
View File

@ -2,43 +2,71 @@ package advsql
import (
"database/sql"
"fmt"
"git.tordarus.net/Tordarus/adverr"
_ "github.com/go-sql-driver/mysql"
)
type Database struct {
db *sql.DB
closefuncs []func() error
stmts map[string]*sql.Stmt
}
func NewDatabase(conn *sql.DB) *Database {
return &Database{
db: conn,
closefuncs: make([]func() error, 0),
func InitDatabase(database *Database, conn *sql.DB) {
panicOnNil(database)
if database.db != nil {
panic("database is initialized already")
}
*database = Database{
db: conn,
stmts: map[string]*sql.Stmt{},
}
database.init()
}
func NewMysqlDatabase(host string, port uint16, user, pass, db string) (*Database, error) {
func InitMysqlDatabase(database *Database, host string, port uint16, user, pass, db string) {
panicOnNil(database)
conn, err := sql.Open("mysql", connString(host, port, user, pass, db))
if err != nil {
return nil, adverr.Wrap("could not connect to database", err)
panic(err)
}
return NewDatabase(conn), nil
InitDatabase(database, conn)
}
func (db *Database) prepare(query string) (*sql.Stmt, error) {
s, err := db.db.Prepare(query)
if err != nil {
return nil, err
func panicOnNil(database *Database) {
if database == nil {
panic("database is nil. initialize database variable with new()")
}
db.closefuncs = append(db.closefuncs, s.Close)
return s, nil
}
func (db *Database) init() {
for _, globalQuery := range globalStmts {
if db == globalQuery.db {
db.stmt(globalQuery.query)
}
}
}
func (db *Database) stmt(query string) *sql.Stmt {
if stmt, ok := db.stmts[query]; ok {
return stmt
}
stmt, err := db.db.Prepare(query)
if err != nil {
panic(fmt.Errorf("compilation failed for query '%s' reason: %w", query, err))
}
db.stmts[query] = stmt
return stmt
}
func (db *Database) Close() error {
for _, close := range db.closefuncs {
close()
for _, stmt := range db.stmts {
stmt.Close()
}
return db.db.Close()
}

View File

@ -9,36 +9,31 @@ import (
_ "github.com/go-sql-driver/mysql"
)
var TestDatabase = new(Database)
type User struct {
Name string
Hash []byte
Salt string
}
func ScanUser(u *User, scan ScanFunc) error {
return scan(&u.Name, &u.Hash, &u.Salt)
func ScanUserPkFirst(u *User, encode ScanFunc) error {
return encode(&u.Name, &u.Hash, &u.Salt)
}
func ExecUser(u *User, exec ExecFunc) error {
return exec(u.Name, u.Hash, u.Salt)
func ScanUserPkLast(u *User, encode ScanFunc) error {
return encode(&u.Hash, &u.Salt, &u.Name)
}
func UpdateUser(u *User, exec ExecFunc) error {
return exec(u.Salt, u.Name)
}
var (
InsertUser = Insert(TestDatabase, "INSERT INTO users VALUES (?, ?, ?)", ScanUserPkFirst)
UpdateUser = Update(TestDatabase, "UPDATE users SET hash = ?, salt = ? WHERE name = ?", ScanUserPkLast)
GetUserByName = QueryOne(TestDatabase, "SELECT * FROM users WHERE namea = ?", ScanUserPkFirst)
)
func TestDB(t *testing.T) {
db, err := NewDatabase("192.168.178.2", 3306, "root", "ZAuJsLdPYFSxHdZon7xpMyh5LW7TPhmM", "users")
if err != nil {
t.Fatal(err)
}
defer db.Close()
insertUser := Insert(db, "INSERT INTO users VALUES (?, ?, ?)", ExecUser)
updateUser := Insert(db, "UPDATE users SET salt = ? WHERE name = ?", UpdateUser)
getUsers := Query(db, "SELECT * FROM users WHERE name = ?", ScanUser)
InitMysqlDatabase(TestDatabase, "ip", 3306, "username", "password", "database")
defer TestDatabase.Close()
pw := sha512.Sum512([]byte("weiter"))
timon := &User{
@ -46,13 +41,11 @@ func TestDB(t *testing.T) {
Hash: pw[:],
Salt: "salt",
}
fmt.Println("insert:", insertUser(timon))
fmt.Println("insert:", InsertUser(timon))
timon.Hash = []byte("asd")
fmt.Println("update:", updateUser(timon))
fmt.Println("update:", UpdateUser(timon))
for user := range getUsers("tordarus") {
user := GetUserByName("tordarus")
fmt.Printf("name: \"%s\" | hash: \"%s\" | salt: \"%s\"\n", user.Name, hex.EncodeToString(user.Hash), user.Salt)
}
}

12
default_decoders.go Normal file
View File

@ -0,0 +1,12 @@
package advsql
import "time"
type defaultDecoderType interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | ~string | ~bool | time.Time
}
// Decoder provides default decoders for primitive datatypes
func Decoder[T defaultDecoderType](value *T, decode ScanFunc) error {
return decode(value)
}

11
delete.go Normal file
View File

@ -0,0 +1,11 @@
package advsql
func Delete(db *Database, query string) DeleteFunc {
prepareGlobal(db, query)
return func(args ...interface{}) error {
s := db.stmt(query)
_, err := s.Exec(args...)
return err
}
}

View File

@ -5,4 +5,12 @@ import "context"
type QueryManyFunc[T any] func(args ...interface{}) <-chan *T
type QueryManyContextFunc[T any] func(ctx context.Context, args ...interface{}) <-chan *T
type QueryOneFunc[T any] func(args ...interface{}) *T
type QueryOneContextFunc[T any] func(ctx context.Context, args ...interface{}) *T
type InsertFunc[T any] func(v *T) error
type UpdateFunc[T any] func(v *T) error
type DeleteFunc func(args ...interface{}) error
type ScanFunc = func(args ...interface{}) error

12
global_stmts.go Normal file
View File

@ -0,0 +1,12 @@
package advsql
type globalStmt struct {
db *Database
query string
}
var globalStmts = []*globalStmt{}
func prepareGlobal(db *Database, query string) {
globalStmts = append(globalStmts, &globalStmt{db, query})
}

7
go.mod
View File

@ -1,8 +1,5 @@
module git.tordarus.net/Tordarus/advsql
module git.milar.in/milarin/advsql
go 1.18
require (
git.tordarus.net/Tordarus/adverr v0.2.0
github.com/go-sql-driver/mysql v1.6.0
)
require github.com/go-sql-driver/mysql v1.6.0

2
go.sum
View File

@ -1,4 +1,2 @@
git.tordarus.net/Tordarus/adverr v0.2.0 h1:kLYjR2/Vb2GHiSAMvAv+WPNaHR9BRphKanf8H/pCZdA=
git.tordarus.net/Tordarus/adverr v0.2.0/go.mod h1:XRf0+7nhOkIEr0gi9DUG4RvV2KaOFB0fYPDaR1KLenw=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=

View File

@ -1,13 +1,11 @@
package advsql
func Insert[T any](db *Database, query string, exec func(v *T, exec ExecFunc) error) InsertFunc[T] {
s, err := db.prepare(query)
if err != nil {
return nil
}
func Insert[T any](db *Database, query string, encoder func(v *T, encode ScanFunc) error) InsertFunc[T] {
prepareGlobal(db, query)
return func(value *T) error {
return exec(value, func(args ...interface{}) error {
s := db.stmt(query)
return encoder(value, func(args ...interface{}) error {
_, err := s.Exec(args...)
return err
})

View File

@ -1,3 +0,0 @@
# advsql
A more advanced ORM for the Golang sql package

View File

@ -1,15 +0,0 @@
package advsql
func (s *Store) checkTableIntegrity(value interface{}) (bool, error) {
valueStructure, err := objectStructure(value)
if err != nil {
return false, err
}
tableStructure, err := s.tableStructure(valueStructure.TypeName)
if err != nil {
return false, err
}
return compareStructures(valueStructure, tableStructure), nil
}

View File

@ -1,34 +0,0 @@
package advsql
import (
"log"
"testing"
_ "github.com/go-sql-driver/mysql"
)
func TestCheckTableIntegrity(t *testing.T) {
store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei")
if err != nil {
t.Fatal("could not make store", err)
}
type Person struct {
Surname string
Lastname string
Age int
}
p := &Person{
Surname: "Timon",
Lastname: "Ringwald",
Age: 25,
}
integrity, err := store.checkTableIntegrity(p)
if err != nil {
t.Fatal("could not check table integrity", err)
}
log.Println("integrity:", integrity)
}

View File

@ -1,96 +0,0 @@
package advsql
import (
"log"
"reflect"
"strings"
)
func (s *Store) createTable(structure *structure) error {
b := new(strings.Builder)
b.WriteString("CREATE TABLE ")
b.WriteString(structure.TypeName)
b.WriteString("(")
fieldAmount := len(structure.Fields)
for i, field := range structure.Fields {
b.WriteString(field.FieldName)
b.WriteString(" ")
b.WriteString(getTypeSpecForDB(field.FieldType))
if i < fieldAmount-1 {
b.WriteString(",")
}
}
b.WriteString(")")
log.Println("sql command: " + b.String())
_, err := s.db.Exec(b.String())
if err != nil {
return err
}
return nil
}
func getTypeSpecForDB(t reflect.Type) string {
noPtrType := t
dereferenced := false
// derefence all pointers recursively
for noPtrType.Kind() == reflect.Ptr {
noPtrType = t.Elem()
dereferenced = true
}
// actual data types
b := new(strings.Builder)
switch noPtrType.Kind() {
case reflect.Bool:
b.WriteString("INT")
case reflect.Int:
b.WriteString("BIGINT")
case reflect.Int8:
b.WriteString("TINYINT")
case reflect.Int16:
b.WriteString("SMALLINT")
case reflect.Int32:
b.WriteString("INT")
case reflect.Int64:
b.WriteString("BIGINT")
case reflect.Uint:
b.WriteString("BIGINT UNSIGNED")
case reflect.Uint8:
b.WriteString("TINYINT UNSIGNED")
case reflect.Uint16:
b.WriteString("SMALLINT UNSIGNED")
case reflect.Uint32:
b.WriteString("INT UNSIGNED")
case reflect.Uint64:
b.WriteString("BIGINT UNSIGNED")
case reflect.Float32:
b.WriteString("FLOAT")
case reflect.Float64:
b.WriteString("DOUBLE")
case reflect.String:
b.WriteString("TEXT")
case reflect.Ptr:
panic("something's broken. All pointers should have been dereferenced by now")
default:
panic("type spec for db not implemented yet: " + noPtrType.Kind().String())
}
if !dereferenced {
b.WriteString(" NOT NULL")
}
return b.String()
}

View File

@ -1,34 +0,0 @@
package advsql
import (
"testing"
)
func TestCreateTable(t *testing.T) {
store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei")
if err != nil {
t.Fatal("could not make store", err)
}
type Person struct {
Surname string
Lastname string
Age int
}
p := &Person{
Surname: "Timon",
Lastname: "Ringwald",
Age: 25,
}
structure, err := objectStructure(p)
if err != nil {
t.Fatal("Could not get structure of value", err)
}
err = store.createTable(structure)
if err != nil {
t.Fatal("Could not create table", err)
}
}

View File

@ -1,11 +0,0 @@
package advsql
func (s *Store) Insert(value interface{}) error {
// structure, err := objectStructure(value)
// if err != nil {
// return err
// }
// return nil
return nil
}

View File

@ -1,27 +0,0 @@
package advsql
import "testing"
func TestInsert(t *testing.T) {
store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei")
if err != nil {
t.Fatal("could not make store", err)
}
type Person struct {
Surname string
Lastname string
Age int
}
p := &Person{
Surname: "Timon",
Lastname: "Ringwald",
Age: 25,
}
err = store.Insert(p)
if err != nil {
t.Fatal("Could not create table", err)
}
}

View File

@ -1,35 +0,0 @@
package advsql
import (
"reflect"
"errors"
)
var ErrNoStructure = errors.New("value is not of type structure")
// analyze returns the structure of value
func objectStructure(value interface{}) (*structure, error) {
t := reflect.TypeOf(value)
k := t.Kind()
// only structs are allowed to be stored in DB
// if pointer: dereference and check again via recursion
if k == reflect.Ptr {
v := reflect.ValueOf(value)
return objectStructure(reflect.Indirect(v).Interface())
} else if k != reflect.Struct {
return nil, ErrNoStructure
}
fieldAmount := t.NumField()
structure := new(structure)
structure.TypeName = t.Name()
structure.Fields = make([]field, fieldAmount)
for fieldIndex := 0; fieldIndex < fieldAmount; fieldIndex++ {
reflectField := t.Field(fieldIndex)
structure.Fields[fieldIndex] = makeFieldUsingReflect(reflectField)
}
return structure, nil
}

View File

@ -1,47 +0,0 @@
package advsql
import (
"database/sql"
"fmt"
)
type Store struct {
db *sql.DB
}
func NewStoreMySQL(host string, port uint16, user, pass, db string) (*Store, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, pass, host, port, db)
conn, err := sql.Open("mysql", dsn)
if err != nil {
return nil, err
}
return NewStore(conn), nil
}
func NewStore(db *sql.DB) *Store {
return &Store{
db: db,
}
}
func (s *Store) Close() error {
stmts := []*sql.Stmt{}
var err error
for _, stmt := range stmts {
err = stmt.Close()
if err != nil {
return err
}
}
err = s.db.Close()
if err != nil {
return err
}
return nil
}

View File

@ -1,29 +0,0 @@
package advsql
func (s *Store) tableStructure(tableName string) (*structure, error) {
rows, err := s.db.Query("SELECT * FROM " + tableName)
if err != nil {
return nil, err
}
defer rows.Close()
types, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
structure := new(structure)
structure.TypeName = tableName
columnAmount := len(types)
structure.Fields = make([]field, columnAmount)
for columnIndex, columType := range types {
structure.Fields[columnIndex] = field{
FieldName: columType.Name(),
FieldType: columType.ScanType(),
}
}
return structure, nil
}

View File

@ -1,94 +0,0 @@
package advsql
import (
"log"
"reflect"
)
type structure struct {
TypeName string
Fields []field
}
type field struct {
FieldName string
FieldType reflect.Type
}
func makeFieldUsingReflect(reflectField reflect.StructField) field {
return field{
FieldName: reflectField.Name,
FieldType: reflectField.Type,
}
}
func compareStructures(first, second *structure) bool {
// compare names
if first.TypeName != second.TypeName {
log.Println("different names")
return false
}
firstFieldAmount, secondFieldAmount := len(first.Fields), len(second.Fields)
// compare field amounts
if firstFieldAmount != secondFieldAmount {
log.Println("different field amounts")
return false
}
// compare fields
for i := 0; i < firstFieldAmount; i++ {
firstField, secondField := first.Fields[i], second.Fields[i]
if !compareFields(firstField, secondField) {
log.Println("different field:", firstField, secondField)
return false
}
}
return true
}
func compareFields(first, second field) bool {
t1, t2 := first.FieldType, second.FieldType
// check if field names are equal
fieldNamesEqual := first.FieldName == second.FieldName
// exceptions for default comparison
switch true {
case checkBoth(first, second, isString, isByteSlice):
fallthrough
case isInt(first) && isInt(second):
fallthrough
case isUint(first) && isUint(second):
return fieldNamesEqual
default:
return t1 == t2
}
}
// TypeCheckFunc represents any function for type checking between Go and DB
type TypeCheckFunc func(field) bool
// checks if both check functions apply to either of the fields
func checkBoth(first, second field, firstCheck, secondCheck TypeCheckFunc) bool {
return firstCheck(first) && secondCheck(second) || firstCheck(second) && secondCheck(first)
}
func isString(field field) bool {
return field.FieldType.Kind() == reflect.String
}
func isInt(field field) bool {
return field.FieldType.Kind() == reflect.Int || field.FieldType.Kind() == reflect.Int64
}
func isUint(field field) bool {
return field.FieldType.Kind() == reflect.Uint || field.FieldType.Kind() == reflect.Uint64
}
func isByteSlice(field field) bool {
return field.FieldType.Kind() == reflect.Slice && field.FieldType.Elem().Kind() == reflect.Uint8
}

View File

@ -1,62 +0,0 @@
package advsql
import "context"
func Query[T any](db *Database, query string, scan func(v *T, scan ScanFunc) error) QueryManyFunc[T] {
s, err := db.prepare(query)
if err != nil {
return nil
}
return func(args ...interface{}) <-chan *T {
out := make(chan *T, 10)
rows, err := s.Query(args...)
if err != nil {
panic(err)
}
go func() {
defer rows.Close()
defer close(out)
for rows.Next() {
v := new(T)
if scan(v, rows.Scan) == nil {
out <- v
}
}
}()
return out
}
}
func QueryContext[T any](db *Database, query string, scan func(v *T, scan ScanFunc) error) QueryManyContextFunc[T] {
s, err := db.db.Prepare(query)
if err != nil {
return nil
}
db.closefuncs = append(db.closefuncs, s.Close)
return func(ctx context.Context, args ...interface{}) <-chan *T {
out := make(chan *T, 10)
rows, err := s.QueryContext(ctx, args...)
if err != nil {
panic(err)
}
go func() {
defer rows.Close()
defer close(out)
for rows.Next() {
v := new(T)
if scan(v, rows.Scan) == nil {
out <- v
}
}
}()
return out
}
}

38
query_many.go Normal file
View File

@ -0,0 +1,38 @@
package advsql
import "context"
func QueryMany[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryManyFunc[T] {
ctxfunc := QueryManyContext(db, query, decoder)
return func(args ...interface{}) <-chan *T {
return ctxfunc(context.Background(), args...)
}
}
func QueryManyContext[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryManyContextFunc[T] {
prepareGlobal(db, query)
return func(ctx context.Context, args ...interface{}) <-chan *T {
s := db.stmt(query)
out := make(chan *T, 10)
rows, err := s.QueryContext(ctx, args...)
if err != nil {
panic(err)
}
go func() {
defer rows.Close()
defer close(out)
for rows.Next() {
v := new(T)
if decoder(v, rows.Scan) == nil {
out <- v
}
}
}()
return out
}
}

19
query_one.go Normal file
View File

@ -0,0 +1,19 @@
package advsql
import "context"
func QueryOne[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryOneFunc[T] {
ctxfunc := QueryOneContext(db, query, decoder)
return func(args ...interface{}) *T {
return ctxfunc(context.Background(), args...)
}
}
func QueryOneContext[T any](db *Database, query string, decoder func(v *T, decode ScanFunc) error) QueryOneContextFunc[T] {
manyfunc := QueryManyContext(db, query, decoder)
return func(ctx context.Context, args ...interface{}) *T {
nctx, cancel := context.WithCancel(ctx)
defer cancel()
return <-manyfunc(nctx, args...)
}
}

View File

@ -1,5 +0,0 @@
package advsql
type ScanFunc func(args ...interface{}) error
type ExecFunc func(args ...interface{}) error

59
stmt.go
View File

@ -1,59 +0,0 @@
package advsql
// import (
// "database/sql"
// )
// type Stmt[M any] struct {
// stmt *sql.Stmt
// scan func(Scanner) (*M, error)
// }
// func NewStmt[M any](db *Database, query string, scan func(s Scanner) (*M, error)) (*Stmt[M], error) {
// s, err := db.db.Prepare(query)
// if err != nil {
// return nil, err
// }
// stmt := Stmt[M]{
// stmt: s,
// scan: scan,
// }
// db.closefuncs = append(db.closefuncs, stmt.Close)
// return &stmt, err
// }
// func (stmt *Stmt[M]) Close() error {
// return stmt.stmt.Close()
// }
// func (stmt *Stmt[M]) Many(args ...interface{}) (<-chan *M, error) {
// rows, err := stmt.stmt.Query(args...)
// if err != nil {
// return nil, err
// }
// out := make(chan *M, 10)
// go func() {
// defer rows.Close()
// defer close(out)
// for rows.Next() {
// if v, err := stmt.scan(rows); err == nil {
// out <- v
// }
// }
// }()
// return out, nil
// }
// func (stmt *Stmt[M]) Single(args ...interface{}) (*M, error) {
// rows, err := stmt.stmt.Query(args...)
// if err != nil {
// return nil, err
// }
// defer rows.Close()
// return stmt.scan(rows)
// }

View File

@ -1,5 +1,5 @@
package advsql
func Update[T any](db *Database, query string, exec func(v *T, exec ExecFunc) error) InsertFunc[T] {
return Insert(db, query, exec)
func Update[T any](db *Database, query string, encoder func(v *T, encode ScanFunc) error) UpdateFunc[T] {
return UpdateFunc[T](Insert(db, query, encoder))
}

View File

@ -3,5 +3,5 @@ package advsql
import "fmt"
func connString(host string, port uint16, user, pass, db string) string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, pass, host, port, db)
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", user, pass, host, port, db)
}