diff --git a/check_table_integrity.go b/check_table_integrity.go new file mode 100644 index 0000000..46ba52b --- /dev/null +++ b/check_table_integrity.go @@ -0,0 +1,15 @@ +package advsql + +func (s *Store) checkTableIntegrity(value interface{}) (bool, error) { + valueStructure, err := objectStructure(value) + if err != nil { + return false, err + } + + tableStructure, err := s.tableStructure(valueStructure.TypeName) + if err != nil { + return false, err + } + + return compareStructures(valueStructure, tableStructure), nil +} diff --git a/check_table_integrity_test.go b/check_table_integrity_test.go new file mode 100644 index 0000000..3cfbc28 --- /dev/null +++ b/check_table_integrity_test.go @@ -0,0 +1,34 @@ +package advsql + +import ( + "log" + "testing" + + _ "github.com/go-sql-driver/mysql" +) + +func TestCheckTableIntegrity(t *testing.T) { + store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei") + if err != nil { + t.Fatal("could not make store", err) + } + + type Person struct { + Surname string + Lastname string + Age int + } + + p := &Person{ + Surname: "Timon", + Lastname: "Ringwald", + Age: 25, + } + + integrity, err := store.checkTableIntegrity(p) + if err != nil { + t.Fatal("could not check table integrity", err) + } + + log.Println("integrity:", integrity) +} diff --git a/create_table.go b/create_table.go new file mode 100644 index 0000000..fab1400 --- /dev/null +++ b/create_table.go @@ -0,0 +1,96 @@ +package advsql + +import ( + "log" + "reflect" + "strings" +) + +func (s *Store) createTable(structure *structure) error { + b := new(strings.Builder) + b.WriteString("CREATE TABLE ") + b.WriteString(structure.TypeName) + b.WriteString("(") + + fieldAmount := len(structure.Fields) + + for i, field := range structure.Fields { + b.WriteString(field.FieldName) + b.WriteString(" ") + b.WriteString(getTypeSpecForDB(field.FieldType)) + + if i < fieldAmount-1 { + b.WriteString(",") + } + } + + b.WriteString(")") + + log.Println("sql command: " + b.String()) + + _, err := s.db.Exec(b.String()) + if err != nil { + return err + } + + return nil +} + +func getTypeSpecForDB(t reflect.Type) string { + noPtrType := t + dereferenced := false + + // derefence all pointers recursively + for noPtrType.Kind() == reflect.Ptr { + noPtrType = t.Elem() + dereferenced = true + } + + // actual data types + b := new(strings.Builder) + switch noPtrType.Kind() { + case reflect.Bool: + b.WriteString("INT") + + case reflect.Int: + b.WriteString("BIGINT") + case reflect.Int8: + b.WriteString("TINYINT") + case reflect.Int16: + b.WriteString("SMALLINT") + case reflect.Int32: + b.WriteString("INT") + case reflect.Int64: + b.WriteString("BIGINT") + + case reflect.Uint: + b.WriteString("BIGINT UNSIGNED") + case reflect.Uint8: + b.WriteString("TINYINT UNSIGNED") + case reflect.Uint16: + b.WriteString("SMALLINT UNSIGNED") + case reflect.Uint32: + b.WriteString("INT UNSIGNED") + case reflect.Uint64: + b.WriteString("BIGINT UNSIGNED") + + case reflect.Float32: + b.WriteString("FLOAT") + case reflect.Float64: + b.WriteString("DOUBLE") + + case reflect.String: + b.WriteString("TEXT") + + case reflect.Ptr: + panic("something's broken. All pointers should have been dereferenced by now") + default: + panic("type spec for db not implemented yet: " + noPtrType.Kind().String()) + } + + if !dereferenced { + b.WriteString(" NOT NULL") + } + + return b.String() +} diff --git a/create_table_test.go b/create_table_test.go new file mode 100644 index 0000000..ed1f837 --- /dev/null +++ b/create_table_test.go @@ -0,0 +1,34 @@ +package advsql + +import ( + "testing" +) + +func TestCreateTable(t *testing.T) { + store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei") + if err != nil { + t.Fatal("could not make store", err) + } + + type Person struct { + Surname string + Lastname string + Age int + } + + p := &Person{ + Surname: "Timon", + Lastname: "Ringwald", + Age: 25, + } + + structure, err := objectStructure(p) + if err != nil { + t.Fatal("Could not get structure of value", err) + } + + err = store.createTable(structure) + if err != nil { + t.Fatal("Could not create table", err) + } +} diff --git a/insert.go b/insert.go new file mode 100644 index 0000000..ee6e9ae --- /dev/null +++ b/insert.go @@ -0,0 +1,10 @@ +package advsql + +func (s *Store) Insert(value interface{}) error { + structure, err := objectStructure(value) + if err != nil { + return err + } + + return nil +} diff --git a/insert_test.go b/insert_test.go new file mode 100644 index 0000000..5681a13 --- /dev/null +++ b/insert_test.go @@ -0,0 +1,27 @@ +package advsql + +import "testing" + +func TestInsert(t *testing.T) { + store, err := NewStoreMySQL("localhost", 3306, "kaikei", "u399D-TFulJykgRF4ijW6G7tK6zdO-jy", "kaikei") + if err != nil { + t.Fatal("could not make store", err) + } + + type Person struct { + Surname string + Lastname string + Age int + } + + p := &Person{ + Surname: "Timon", + Lastname: "Ringwald", + Age: 25, + } + + err = store.Insert(p) + if err != nil { + t.Fatal("Could not create table", err) + } +} diff --git a/object_structure.go b/object_structure.go new file mode 100644 index 0000000..46b2334 --- /dev/null +++ b/object_structure.go @@ -0,0 +1,35 @@ +package advsql + +import ( + "reflect" + "errors" +) + +var ErrNoStructure = errors.New("value is not of type structure") + +// analyze returns the structure of value +func objectStructure(value interface{}) (*structure, error) { + t := reflect.TypeOf(value) + k := t.Kind() + + // only structs are allowed to be stored in DB + // if pointer: dereference and check again via recursion + if k == reflect.Ptr { + v := reflect.ValueOf(value) + return objectStructure(reflect.Indirect(v).Interface()) + } else if k != reflect.Struct { + return nil, ErrNoStructure + } + + fieldAmount := t.NumField() + structure := new(structure) + structure.TypeName = t.Name() + structure.Fields = make([]field, fieldAmount) + + for fieldIndex := 0; fieldIndex < fieldAmount; fieldIndex++ { + reflectField := t.Field(fieldIndex) + structure.Fields[fieldIndex] = makeFieldUsingReflect(reflectField) + } + + return structure, nil +} \ No newline at end of file diff --git a/store.go b/store.go new file mode 100644 index 0000000..0ed9ccb --- /dev/null +++ b/store.go @@ -0,0 +1,47 @@ +package advsql + +import ( + "database/sql" + "fmt" +) + +type Store struct { + db *sql.DB +} + +func NewStoreMySQL(host string, port uint16, user, pass, db string) (*Store, error) { + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, pass, host, port, db) + + conn, err := sql.Open("mysql", dsn) + if err != nil { + return nil, err + } + + return NewStore(conn), nil +} + +func NewStore(db *sql.DB) *Store { + return &Store{ + db: db, + } +} + +func (s *Store) Close() error { + stmts := []*sql.Stmt{} + + var err error + + for _, stmt := range stmts { + err = stmt.Close() + if err != nil { + return err + } + } + + err = s.db.Close() + if err != nil { + return err + } + + return nil +} diff --git a/table_structure.go b/table_structure.go new file mode 100644 index 0000000..07b5b8b --- /dev/null +++ b/table_structure.go @@ -0,0 +1,29 @@ +package advsql + +func (s *Store) tableStructure(tableName string) (*structure, error) { + rows, err := s.db.Query("SELECT * FROM " + tableName) + if err != nil { + return nil, err + } + defer rows.Close() + + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + structure := new(structure) + structure.TypeName = tableName + + columnAmount := len(types) + structure.Fields = make([]field, columnAmount) + + for columnIndex, columType := range types { + structure.Fields[columnIndex] = field{ + FieldName: columType.Name(), + FieldType: columType.ScanType(), + } + } + + return structure, nil +} \ No newline at end of file diff --git a/types.go b/types.go new file mode 100644 index 0000000..f3ed1f1 --- /dev/null +++ b/types.go @@ -0,0 +1,94 @@ +package advsql + +import ( + "log" + "reflect" +) + +type structure struct { + TypeName string + Fields []field +} + +type field struct { + FieldName string + FieldType reflect.Type +} + +func makeFieldUsingReflect(reflectField reflect.StructField) field { + return field{ + FieldName: reflectField.Name, + FieldType: reflectField.Type, + } +} + +func compareStructures(first, second *structure) bool { + // compare names + if first.TypeName != second.TypeName { + log.Println("different names") + return false + } + + firstFieldAmount, secondFieldAmount := len(first.Fields), len(second.Fields) + + // compare field amounts + if firstFieldAmount != secondFieldAmount { + log.Println("different field amounts") + return false + } + + // compare fields + for i := 0; i < firstFieldAmount; i++ { + firstField, secondField := first.Fields[i], second.Fields[i] + if !compareFields(firstField, secondField) { + log.Println("different field:", firstField, secondField) + return false + } + } + + return true +} + +func compareFields(first, second field) bool { + t1, t2 := first.FieldType, second.FieldType + + // check if field names are equal + fieldNamesEqual := first.FieldName == second.FieldName + + // exceptions for default comparison + switch true { + case checkBoth(first, second, isString, isByteSlice): + fallthrough + case isInt(first) && isInt(second): + fallthrough + case isUint(first) && isUint(second): + return fieldNamesEqual + + default: + return t1 == t2 + } +} + +// TypeCheckFunc represents any function for type checking between Go and DB +type TypeCheckFunc func(field) bool + +// checks if both check functions apply to either of the fields +func checkBoth(first, second field, firstCheck, secondCheck TypeCheckFunc) bool { + return firstCheck(first) && secondCheck(second) || firstCheck(second) && secondCheck(first) +} + +func isString(field field) bool { + return field.FieldType.Kind() == reflect.String +} + +func isInt(field field) bool { + return field.FieldType.Kind() == reflect.Int || field.FieldType.Kind() == reflect.Int64 +} + +func isUint(field field) bool { + return field.FieldType.Kind() == reflect.Uint || field.FieldType.Kind() == reflect.Uint64 +} + +func isByteSlice(field field) bool { + return field.FieldType.Kind() == reflect.Slice && field.FieldType.Elem().Kind() == reflect.Uint8 +}