diff --git a/db_test.go b/db_test.go index 44a03c9..444bdcd 100644 --- a/db_test.go +++ b/db_test.go @@ -36,7 +36,7 @@ func TestDB(t *testing.T) { insertUser := Insert(db, "INSERT INTO users VALUES (?, ?, ?)", InsertUserEncoder) - updateUser := Insert(db, "UPDATE users SET salt = ? WHERE name = ?", UpdateUserByNameEncoder) + updateUser := Update(db, "UPDATE users SET salt = ? WHERE name = ?", UpdateUserByNameEncoder) getUsers := Query(db, "SELECT * FROM users WHERE name = ?", UserDecoder) diff --git a/default_decoders.go b/default_decoders.go index 2af1019..29947e9 100644 --- a/default_decoders.go +++ b/default_decoders.go @@ -1,45 +1,12 @@ package advsql -func IntDecoder(value *int, decode DecodeFunc) error { - return decode(value) +import "time" + +type defaultDecoderType interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | ~string | ~bool | time.Time } -func Int8Decoder(value *int8, decode DecodeFunc) error { - return decode(value) -} - -func Int16Decoder(value *int16, decode DecodeFunc) error { - return decode(value) -} - -func Int32Decoder(value *int32, decode DecodeFunc) error { - return decode(value) -} - -func Int64Decoder(value *int64, decode DecodeFunc) error { - return decode(value) -} - -func Uint8Decoder(value *uint8, decode DecodeFunc) error { - return decode(value) -} - -func Uint16Decoder(value *uint16, decode DecodeFunc) error { - return decode(value) -} - -func Uint32Decoder(value *uint32, decode DecodeFunc) error { - return decode(value) -} - -func Uint64Decoder(value *uint64, decode DecodeFunc) error { - return decode(value) -} - -func Float32Decoder(value *float32, decode DecodeFunc) error { - return decode(value) -} - -func Float64Decoder(value *float64, decode DecodeFunc) error { +// Decoder provides default decoders for primitive datatypes +func Decoder[T defaultDecoderType](value *T, decode DecodeFunc) error { return decode(value) } diff --git a/func_types.go b/func_types.go index 9807a14..3a7253b 100644 --- a/func_types.go +++ b/func_types.go @@ -5,4 +5,8 @@ import "context" type QueryManyFunc[T any] func(args ...interface{}) <-chan *T type QueryManyContextFunc[T any] func(ctx context.Context, args ...interface{}) <-chan *T +type QueryOneFunc[T any] func(args ...interface{}) *T +type QueryOneContextFunc[T any] func(ctx context.Context, args ...interface{}) *T + type InsertFunc[T any] func(v *T) error +type UpdateFunc[T any] func(v *T) error diff --git a/query.go b/query.go deleted file mode 100644 index 1442c9b..0000000 --- a/query.go +++ /dev/null @@ -1,62 +0,0 @@ -package advsql - -import "context" - -func Query[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryManyFunc[T] { - s, err := db.prepare(query) - if err != nil { - panic(err) - } - - return func(args ...interface{}) <-chan *T { - out := make(chan *T, 10) - - rows, err := s.Query(args...) - if err != nil { - panic(err) - } - - go func() { - defer rows.Close() - defer close(out) - for rows.Next() { - v := new(T) - if decoder(v, rows.Scan) == nil { - out <- v - } - } - }() - - return out - } -} - -func QueryContext[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryManyContextFunc[T] { - s, err := db.db.Prepare(query) - if err != nil { - return nil - } - db.closefuncs = append(db.closefuncs, s.Close) - - return func(ctx context.Context, args ...interface{}) <-chan *T { - out := make(chan *T, 10) - - rows, err := s.QueryContext(ctx, args...) - if err != nil { - panic(err) - } - - go func() { - defer rows.Close() - defer close(out) - for rows.Next() { - v := new(T) - if decoder(v, rows.Scan) == nil { - out <- v - } - } - }() - - return out - } -} diff --git a/query_many.go b/query_many.go new file mode 100644 index 0000000..51ea7ac --- /dev/null +++ b/query_many.go @@ -0,0 +1,40 @@ +package advsql + +import "context" + +func QueryMany[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryManyFunc[T] { + ctxfunc := QueryManyContext(db, query, decoder) + return func(args ...interface{}) <-chan *T { + return ctxfunc(context.Background(), args...) + } +} + +func QueryManyContext[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryManyContextFunc[T] { + s, err := db.db.Prepare(query) + if err != nil { + return nil + } + db.closefuncs = append(db.closefuncs, s.Close) + + return func(ctx context.Context, args ...interface{}) <-chan *T { + out := make(chan *T, 10) + + rows, err := s.QueryContext(ctx, args...) + if err != nil { + panic(err) + } + + go func() { + defer rows.Close() + defer close(out) + for rows.Next() { + v := new(T) + if decoder(v, rows.Scan) == nil { + out <- v + } + } + }() + + return out + } +} diff --git a/query_one.go b/query_one.go new file mode 100644 index 0000000..59d6a11 --- /dev/null +++ b/query_one.go @@ -0,0 +1,19 @@ +package advsql + +import "context" + +func QueryOne[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryOneFunc[T] { + ctxfunc := QueryOneContext(db, query, decoder) + return func(args ...interface{}) *T { + return ctxfunc(context.Background(), args...) + } +} + +func QueryOneContext[T any](db *Database, query string, decoder func(v *T, decode DecodeFunc) error) QueryOneContextFunc[T] { + manyfunc := QueryManyContext(db, query, decoder) + return func(ctx context.Context, args ...interface{}) *T { + nctx, cancel := context.WithCancel(ctx) + defer cancel() + return <-manyfunc(nctx, args...) + } +} diff --git a/update.go b/update.go index b97ad72..0ca514c 100644 --- a/update.go +++ b/update.go @@ -1,5 +1,5 @@ package advsql -func Update[T any](db *Database, query string, encoder func(v *T, encode EncodeFunc) error) InsertFunc[T] { - return Insert(db, query, encoder) +func Update[T any](db *Database, query string, encoder func(v *T, encode EncodeFunc) error) UpdateFunc[T] { + return UpdateFunc[T](Insert(db, query, encoder)) }