From d6e73f69c0aa61abab33f88ce7ddd6045b83d275 Mon Sep 17 00:00:00 2001 From: Timon Ringwald Date: Tue, 5 Jul 2022 12:38:39 +0200 Subject: [PATCH] rewrite --- README.md | 3 -- check_table_integrity.go | 15 ------ check_table_integrity_test.go | 34 ------------- create_table.go | 96 ----------------------------------- create_table_test.go | 34 ------------- db.go | 41 +++++++++++++++ db_test.go | 58 +++++++++++++++++++++ func_types.go | 8 +++ go.mod | 8 +++ go.sum | 4 ++ insert.go | 13 +++-- insert_test.go | 27 ---------- object_structure.go | 35 ------------- query.go | 62 ++++++++++++++++++++++ scanner.go | 5 ++ stmt.go | 59 +++++++++++++++++++++ store.go | 47 ----------------- table_structure.go | 29 ----------- types.go | 94 ---------------------------------- update.go | 5 ++ utils.go | 7 +++ 21 files changed, 266 insertions(+), 418 deletions(-) delete mode 100644 README.md delete mode 100644 check_table_integrity.go delete mode 100644 check_table_integrity_test.go delete mode 100644 create_table.go delete mode 100644 create_table_test.go create mode 100644 db.go create mode 100644 db_test.go create mode 100644 func_types.go create mode 100644 go.mod create mode 100644 go.sum delete mode 100644 insert_test.go delete mode 100644 object_structure.go create mode 100644 query.go create mode 100644 scanner.go create mode 100644 stmt.go delete mode 100644 store.go delete mode 100644 table_structure.go delete mode 100644 types.go create mode 100644 update.go create mode 100644 utils.go diff --git a/README.md b/README.md deleted file mode 100644 index 352ffa9..0000000 --- a/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# advsql - -A more advanced ORM for the Golang sql package \ No newline at end of file diff --git a/check_table_integrity.go b/check_table_integrity.go deleted file mode 100644 index 46ba52b..0000000 --- a/check_table_integrity.go +++ /dev/null @@ -1,15 +0,0 @@ -package advsql - -func (s *Store) checkTableIntegrity(value interface{}) (bool, error) { - valueStructure, err := objectStructure(value) - if err != nil { - return false, err - } - - tableStructure, err := s.tableStructure(valueStructure.TypeName) - if err != nil { - return false, err - } - - return compareStructures(valueStructure, tableStructure), nil -} diff --git a/check_table_integrity_test.go b/check_table_integrity_test.go deleted file mode 100644 index 3cfbc28..0000000 --- a/check_table_integrity_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package advsql - -import ( - "log" - "testing" - - _ "github.com/go-sql-driver/mysql" -) - -func TestCheckTableIntegrity(t *testing.T) { - store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei") - if err != nil { - t.Fatal("could not make store", err) - } - - type Person struct { - Surname string - Lastname string - Age int - } - - p := &Person{ - Surname: "Timon", - Lastname: "Ringwald", - Age: 25, - } - - integrity, err := store.checkTableIntegrity(p) - if err != nil { - t.Fatal("could not check table integrity", err) - } - - log.Println("integrity:", integrity) -} diff --git a/create_table.go b/create_table.go deleted file mode 100644 index fab1400..0000000 --- a/create_table.go +++ /dev/null @@ -1,96 +0,0 @@ -package advsql - -import ( - "log" - "reflect" - "strings" -) - -func (s *Store) createTable(structure *structure) error { - b := new(strings.Builder) - b.WriteString("CREATE TABLE ") - b.WriteString(structure.TypeName) - b.WriteString("(") - - fieldAmount := len(structure.Fields) - - for i, field := range structure.Fields { - b.WriteString(field.FieldName) - b.WriteString(" ") - b.WriteString(getTypeSpecForDB(field.FieldType)) - - if i < fieldAmount-1 { - b.WriteString(",") - } - } - - b.WriteString(")") - - log.Println("sql command: " + b.String()) - - _, err := s.db.Exec(b.String()) - if err != nil { - return err - } - - return nil -} - -func getTypeSpecForDB(t reflect.Type) string { - noPtrType := t - dereferenced := false - - // derefence all pointers recursively - for noPtrType.Kind() == reflect.Ptr { - noPtrType = t.Elem() - dereferenced = true - } - - // actual data types - b := new(strings.Builder) - switch noPtrType.Kind() { - case reflect.Bool: - b.WriteString("INT") - - case reflect.Int: - b.WriteString("BIGINT") - case reflect.Int8: - b.WriteString("TINYINT") - case reflect.Int16: - b.WriteString("SMALLINT") - case reflect.Int32: - b.WriteString("INT") - case reflect.Int64: - b.WriteString("BIGINT") - - case reflect.Uint: - b.WriteString("BIGINT UNSIGNED") - case reflect.Uint8: - b.WriteString("TINYINT UNSIGNED") - case reflect.Uint16: - b.WriteString("SMALLINT UNSIGNED") - case reflect.Uint32: - b.WriteString("INT UNSIGNED") - case reflect.Uint64: - b.WriteString("BIGINT UNSIGNED") - - case reflect.Float32: - b.WriteString("FLOAT") - case reflect.Float64: - b.WriteString("DOUBLE") - - case reflect.String: - b.WriteString("TEXT") - - case reflect.Ptr: - panic("something's broken. All pointers should have been dereferenced by now") - default: - panic("type spec for db not implemented yet: " + noPtrType.Kind().String()) - } - - if !dereferenced { - b.WriteString(" NOT NULL") - } - - return b.String() -} diff --git a/create_table_test.go b/create_table_test.go deleted file mode 100644 index ed1f837..0000000 --- a/create_table_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package advsql - -import ( - "testing" -) - -func TestCreateTable(t *testing.T) { - store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei") - if err != nil { - t.Fatal("could not make store", err) - } - - type Person struct { - Surname string - Lastname string - Age int - } - - p := &Person{ - Surname: "Timon", - Lastname: "Ringwald", - Age: 25, - } - - structure, err := objectStructure(p) - if err != nil { - t.Fatal("Could not get structure of value", err) - } - - err = store.createTable(structure) - if err != nil { - t.Fatal("Could not create table", err) - } -} diff --git a/db.go b/db.go new file mode 100644 index 0000000..da1e5fa --- /dev/null +++ b/db.go @@ -0,0 +1,41 @@ +package advsql + +import ( + "database/sql" + + "git.tordarus.net/Tordarus/adverr" + _ "github.com/go-sql-driver/mysql" +) + +type Database struct { + db *sql.DB + closefuncs []func() error +} + +func NewDatabase(host string, port uint16, user, pass, db string) (*Database, error) { + conn, err := sql.Open("mysql", connString(host, port, user, pass, db)) + if err != nil { + return nil, adverr.Wrap("could not connect to database", err) + } + + return &Database{ + db: conn, + closefuncs: make([]func() error, 0), + }, nil +} + +func (db *Database) prepare(query string) (*sql.Stmt, error) { + s, err := db.db.Prepare(query) + if err != nil { + return nil, err + } + db.closefuncs = append(db.closefuncs, s.Close) + return s, nil +} + +func (db *Database) Close() error { + for _, close := range db.closefuncs { + close() + } + return db.db.Close() +} diff --git a/db_test.go b/db_test.go new file mode 100644 index 0000000..e36c191 --- /dev/null +++ b/db_test.go @@ -0,0 +1,58 @@ +package advsql + +import ( + "crypto/sha512" + "encoding/hex" + "fmt" + "testing" + + _ "github.com/go-sql-driver/mysql" +) + +type User struct { + Name string + Hash []byte + Salt string +} + +func ScanUser(u *User, scan ScanFunc) error { + return scan(&u.Name, &u.Hash, &u.Salt) +} + +func ExecUser(u *User, exec ExecFunc) error { + return exec(u.Name, u.Hash, u.Salt) +} + +func UpdateUser(u *User, exec ExecFunc) error { + return exec(u.Salt, u.Name) +} + +func TestDB(t *testing.T) { + db, err := NewDatabase("192.168.178.2", 3306, "root", "ZAuJsLdPYFSxHdZon7xpMyh5LW7TPhmM", "users") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + insertUser := Insert(db, "INSERT INTO users VALUES (?, ?, ?)", ExecUser) + + updateUser := Insert(db, "UPDATE users SET salt = ? WHERE name = ?", UpdateUser) + + getUsers := Query(db, "SELECT * FROM users WHERE name = ?", ScanUser) + + pw := sha512.Sum512([]byte("weiter")) + timon := &User{ + Name: "timon", + Hash: pw[:], + Salt: "salt", + } + fmt.Println("insert:", insertUser(timon)) + + timon.Hash = []byte("asd") + fmt.Println("update:", updateUser(timon)) + + for user := range getUsers("tordarus") { + fmt.Printf("name: \"%s\" | hash: \"%s\" | salt: \"%s\"\n", user.Name, hex.EncodeToString(user.Hash), user.Salt) + } + +} diff --git a/func_types.go b/func_types.go new file mode 100644 index 0000000..9807a14 --- /dev/null +++ b/func_types.go @@ -0,0 +1,8 @@ +package advsql + +import "context" + +type QueryManyFunc[T any] func(args ...interface{}) <-chan *T +type QueryManyContextFunc[T any] func(ctx context.Context, args ...interface{}) <-chan *T + +type InsertFunc[T any] func(v *T) error diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8b279a8 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module git.tordarus.net/Tordarus/advsql + +go 1.18 + +require ( + git.tordarus.net/Tordarus/adverr v0.2.0 + github.com/go-sql-driver/mysql v1.6.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..184b59a --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +git.tordarus.net/Tordarus/adverr v0.2.0 h1:kLYjR2/Vb2GHiSAMvAv+WPNaHR9BRphKanf8H/pCZdA= +git.tordarus.net/Tordarus/adverr v0.2.0/go.mod h1:XRf0+7nhOkIEr0gi9DUG4RvV2KaOFB0fYPDaR1KLenw= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= diff --git a/insert.go b/insert.go index ee6e9ae..8fa7e92 100644 --- a/insert.go +++ b/insert.go @@ -1,10 +1,15 @@ package advsql -func (s *Store) Insert(value interface{}) error { - structure, err := objectStructure(value) +func Insert[T any](db *Database, query string, exec func(v *T, exec ExecFunc) error) InsertFunc[T] { + s, err := db.prepare(query) if err != nil { - return err + return nil } - return nil + return func(value *T) error { + return exec(value, func(args ...interface{}) error { + _, err := s.Exec(args...) + return err + }) + } } diff --git a/insert_test.go b/insert_test.go deleted file mode 100644 index 5681a13..0000000 --- a/insert_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package advsql - -import "testing" - -func TestInsert(t *testing.T) { - store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei") - if err != nil { - t.Fatal("could not make store", err) - } - - type Person struct { - Surname string - Lastname string - Age int - } - - p := &Person{ - Surname: "Timon", - Lastname: "Ringwald", - Age: 25, - } - - err = store.Insert(p) - if err != nil { - t.Fatal("Could not create table", err) - } -} diff --git a/object_structure.go b/object_structure.go deleted file mode 100644 index 46b2334..0000000 --- a/object_structure.go +++ /dev/null @@ -1,35 +0,0 @@ -package advsql - -import ( - "reflect" - "errors" -) - -var ErrNoStructure = errors.New("value is not of type structure") - -// analyze returns the structure of value -func objectStructure(value interface{}) (*structure, error) { - t := reflect.TypeOf(value) - k := t.Kind() - - // only structs are allowed to be stored in DB - // if pointer: dereference and check again via recursion - if k == reflect.Ptr { - v := reflect.ValueOf(value) - return objectStructure(reflect.Indirect(v).Interface()) - } else if k != reflect.Struct { - return nil, ErrNoStructure - } - - fieldAmount := t.NumField() - structure := new(structure) - structure.TypeName = t.Name() - structure.Fields = make([]field, fieldAmount) - - for fieldIndex := 0; fieldIndex < fieldAmount; fieldIndex++ { - reflectField := t.Field(fieldIndex) - structure.Fields[fieldIndex] = makeFieldUsingReflect(reflectField) - } - - return structure, nil -} \ No newline at end of file diff --git a/query.go b/query.go new file mode 100644 index 0000000..a711004 --- /dev/null +++ b/query.go @@ -0,0 +1,62 @@ +package advsql + +import "context" + +func Query[T any](db *Database, query string, scan func(v *T, scan ScanFunc) error) QueryManyFunc[T] { + s, err := db.prepare(query) + if err != nil { + return nil + } + + return func(args ...interface{}) <-chan *T { + out := make(chan *T, 10) + + rows, err := s.Query(args...) + if err != nil { + panic(err) + } + + go func() { + defer rows.Close() + defer close(out) + for rows.Next() { + v := new(T) + if scan(v, rows.Scan) == nil { + out <- v + } + } + }() + + return out + } +} + +func QueryContext[T any](db *Database, query string, scan func(v *T, scan ScanFunc) error) QueryManyContextFunc[T] { + s, err := db.db.Prepare(query) + if err != nil { + return nil + } + db.closefuncs = append(db.closefuncs, s.Close) + + return func(ctx context.Context, args ...interface{}) <-chan *T { + out := make(chan *T, 10) + + rows, err := s.QueryContext(ctx, args...) + if err != nil { + panic(err) + } + + go func() { + defer rows.Close() + defer close(out) + for rows.Next() { + v := new(T) + if scan(v, rows.Scan) == nil { + out <- v + } + } + }() + + return out + } +} diff --git a/scanner.go b/scanner.go new file mode 100644 index 0000000..e3cf9b6 --- /dev/null +++ b/scanner.go @@ -0,0 +1,5 @@ +package advsql + +type ScanFunc func(args ...interface{}) error + +type ExecFunc func(args ...interface{}) error diff --git a/stmt.go b/stmt.go new file mode 100644 index 0000000..d759cc2 --- /dev/null +++ b/stmt.go @@ -0,0 +1,59 @@ +package advsql + +// import ( +// "database/sql" +// ) + +// type Stmt[M any] struct { +// stmt *sql.Stmt +// scan func(Scanner) (*M, error) +// } + +// func NewStmt[M any](db *Database, query string, scan func(s Scanner) (*M, error)) (*Stmt[M], error) { +// s, err := db.db.Prepare(query) +// if err != nil { +// return nil, err +// } + +// stmt := Stmt[M]{ +// stmt: s, +// scan: scan, +// } +// db.closefuncs = append(db.closefuncs, stmt.Close) +// return &stmt, err +// } + +// func (stmt *Stmt[M]) Close() error { +// return stmt.stmt.Close() +// } + +// func (stmt *Stmt[M]) Many(args ...interface{}) (<-chan *M, error) { +// rows, err := stmt.stmt.Query(args...) +// if err != nil { +// return nil, err +// } + +// out := make(chan *M, 10) + +// go func() { +// defer rows.Close() +// defer close(out) + +// for rows.Next() { +// if v, err := stmt.scan(rows); err == nil { +// out <- v +// } +// } +// }() + +// return out, nil +// } + +// func (stmt *Stmt[M]) Single(args ...interface{}) (*M, error) { +// rows, err := stmt.stmt.Query(args...) +// if err != nil { +// return nil, err +// } +// defer rows.Close() +// return stmt.scan(rows) +// } diff --git a/store.go b/store.go deleted file mode 100644 index 0ed9ccb..0000000 --- a/store.go +++ /dev/null @@ -1,47 +0,0 @@ -package advsql - -import ( - "database/sql" - "fmt" -) - -type Store struct { - db *sql.DB -} - -func NewStoreMySQL(host string, port uint16, user, pass, db string) (*Store, error) { - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, pass, host, port, db) - - conn, err := sql.Open("mysql", dsn) - if err != nil { - return nil, err - } - - return NewStore(conn), nil -} - -func NewStore(db *sql.DB) *Store { - return &Store{ - db: db, - } -} - -func (s *Store) Close() error { - stmts := []*sql.Stmt{} - - var err error - - for _, stmt := range stmts { - err = stmt.Close() - if err != nil { - return err - } - } - - err = s.db.Close() - if err != nil { - return err - } - - return nil -} diff --git a/table_structure.go b/table_structure.go deleted file mode 100644 index 07b5b8b..0000000 --- a/table_structure.go +++ /dev/null @@ -1,29 +0,0 @@ -package advsql - -func (s *Store) tableStructure(tableName string) (*structure, error) { - rows, err := s.db.Query("SELECT * FROM " + tableName) - if err != nil { - return nil, err - } - defer rows.Close() - - types, err := rows.ColumnTypes() - if err != nil { - return nil, err - } - - structure := new(structure) - structure.TypeName = tableName - - columnAmount := len(types) - structure.Fields = make([]field, columnAmount) - - for columnIndex, columType := range types { - structure.Fields[columnIndex] = field{ - FieldName: columType.Name(), - FieldType: columType.ScanType(), - } - } - - return structure, nil -} \ No newline at end of file diff --git a/types.go b/types.go deleted file mode 100644 index f3ed1f1..0000000 --- a/types.go +++ /dev/null @@ -1,94 +0,0 @@ -package advsql - -import ( - "log" - "reflect" -) - -type structure struct { - TypeName string - Fields []field -} - -type field struct { - FieldName string - FieldType reflect.Type -} - -func makeFieldUsingReflect(reflectField reflect.StructField) field { - return field{ - FieldName: reflectField.Name, - FieldType: reflectField.Type, - } -} - -func compareStructures(first, second *structure) bool { - // compare names - if first.TypeName != second.TypeName { - log.Println("different names") - return false - } - - firstFieldAmount, secondFieldAmount := len(first.Fields), len(second.Fields) - - // compare field amounts - if firstFieldAmount != secondFieldAmount { - log.Println("different field amounts") - return false - } - - // compare fields - for i := 0; i < firstFieldAmount; i++ { - firstField, secondField := first.Fields[i], second.Fields[i] - if !compareFields(firstField, secondField) { - log.Println("different field:", firstField, secondField) - return false - } - } - - return true -} - -func compareFields(first, second field) bool { - t1, t2 := first.FieldType, second.FieldType - - // check if field names are equal - fieldNamesEqual := first.FieldName == second.FieldName - - // exceptions for default comparison - switch true { - case checkBoth(first, second, isString, isByteSlice): - fallthrough - case isInt(first) && isInt(second): - fallthrough - case isUint(first) && isUint(second): - return fieldNamesEqual - - default: - return t1 == t2 - } -} - -// TypeCheckFunc represents any function for type checking between Go and DB -type TypeCheckFunc func(field) bool - -// checks if both check functions apply to either of the fields -func checkBoth(first, second field, firstCheck, secondCheck TypeCheckFunc) bool { - return firstCheck(first) && secondCheck(second) || firstCheck(second) && secondCheck(first) -} - -func isString(field field) bool { - return field.FieldType.Kind() == reflect.String -} - -func isInt(field field) bool { - return field.FieldType.Kind() == reflect.Int || field.FieldType.Kind() == reflect.Int64 -} - -func isUint(field field) bool { - return field.FieldType.Kind() == reflect.Uint || field.FieldType.Kind() == reflect.Uint64 -} - -func isByteSlice(field field) bool { - return field.FieldType.Kind() == reflect.Slice && field.FieldType.Elem().Kind() == reflect.Uint8 -} diff --git a/update.go b/update.go new file mode 100644 index 0000000..cf885f8 --- /dev/null +++ b/update.go @@ -0,0 +1,5 @@ +package advsql + +func Update[T any](db *Database, query string, exec func(v *T, exec ExecFunc) error) InsertFunc[T] { + return nil +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..25f35e5 --- /dev/null +++ b/utils.go @@ -0,0 +1,7 @@ +package advsql + +import "fmt" + +func connString(host string, port uint16, user, pass, db string) string { + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, pass, host, port, db) +}