From a6d99c0ccc4692980cad380fbe34270115c03cba Mon Sep 17 00:00:00 2001 From: Timon Ringwald Date: Tue, 6 Sep 2022 11:22:28 +0200 Subject: [PATCH] mysql db for session storage --- bookmark.go | 4 +- db.go | 90 ++++++++++++++++++++++++++++ go.mod | 10 +--- go.sum | 12 ++-- main.go | 137 +++++++++++++++++++++---------------------- settings.go | 2 + templates/index.html | 6 ++ utils.go | 85 +++++++++++++++++++++------ 8 files changed, 242 insertions(+), 104 deletions(-) create mode 100644 db.go diff --git a/bookmark.go b/bookmark.go index 7312da0..dd8aa35 100644 --- a/bookmark.go +++ b/bookmark.go @@ -78,7 +78,9 @@ func DefaultBookmarks() []Bookmark { } type Bookmark struct { - Title string `json:"title"` + SessionID string `json:"-"` + Title string `json:"title"` + Image string `json:"image"` ImageSize string `json:"image_size"` IconPadding string `json:"icon_padding"` diff --git a/db.go b/db.go new file mode 100644 index 0000000..5139ac6 --- /dev/null +++ b/db.go @@ -0,0 +1,90 @@ +package main + +import ( + "time" + + "git.milar.in/milarin/advsql" +) + +var Database = &advsql.Database{} + +var ( + InsertSession = advsql.Insert(Database, "INSERT INTO sessions VALUES (?, ?, ?)", ScanSessionPkFirst) + UpdateSession = advsql.Update(Database, "UPDATE sessions SET expiration_date = ?, creation_date = ? WHERE id = ?", ScanSessionPkLast) + GetSessionByID = advsql.QueryOne(Database, "SELECT * FROM sessions WHERE id = ?", ScanSessionPkFirst) + + InsertBookmark = advsql.Insert(Database, "INSERT INTO bookmarks VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", ScanBookmarkPkFirst) + UpdateBookmark = advsql.Update(Database, "UPDATE bookmarks SET link = ?, image = ?, color = ?, image_size = ?, icon_padding = ?, hide_border = ?, order_priority = ? WHERE session_id = ? AND title = ?", ScanBookmarkPkLast) + GetBookmarkBySessionIDAndTitle = advsql.QueryOne(Database, "SELECT * FROM bookmarks WHERE session_id = ? AND title = ?", ScanBookmarkPkFirst) + GetBookmarksBySessionIdOrdered = advsql.QueryMany(Database, "SELECT * FROM bookmarks WHERE session_id = ? ORDER BY order_priority", ScanBookmarkPkFirst) + DeleteBookmarksBySessionID = advsql.Delete(Database, "DELETE FROM bookmarks WHERE session_id = ?") + + InsertSettings = advsql.Insert(Database, "INSERT INTO settings VALUES (?, ?, ?, ?, ?)", ScanSettingsPkFirst) + UpdateSettings = advsql.Update(Database, "UPDATE settings SET background_color = ?, foreground_color = ?, search_query = ?, border_radius = ? WHERE session_id = ?", ScanSettingsPkLast) + GetSettingsBySessionID = advsql.QueryOne(Database, "SELECT * FROM settings WHERE session_id = ?", ScanSettingsPkFirst) +) + +type Session struct { + ID string + + ExpirationDate time.Time + CreationDate time.Time +} + +func ScanSessionPkFirst(u *Session, encode advsql.ScanFunc) error { + return encode(&u.ID, &u.ExpirationDate, &u.CreationDate) +} + +func ScanSessionPkLast(u *Session, encode advsql.ScanFunc) error { + return encode(&u.ExpirationDate, &u.CreationDate, &u.ID) +} + +func NewSession() (*Session, error) { + id, err := generateRandomID(32) + if err != nil { + return nil, err + } + + return &Session{ + ID: id, + + CreationDate: time.Now(), + ExpirationDate: time.Now().AddDate(1, 0, 0), + }, nil +} + +func ScanBookmarkPkFirst(u *Bookmark, encode advsql.ScanFunc) error { + return encode( + &u.SessionID, + &u.Title, + &u.Link, + &u.Image, + &u.Color, + &u.ImageSize, + &u.IconPadding, + &u.HideBorder, + &u.Order, + ) +} + +func ScanBookmarkPkLast(u *Bookmark, encode advsql.ScanFunc) error { + return encode( + &u.Link, + &u.Image, + &u.Color, + &u.ImageSize, + &u.IconPadding, + &u.HideBorder, + &u.Order, + &u.SessionID, + &u.Title, + ) +} + +func ScanSettingsPkFirst(u *Settings, encode advsql.ScanFunc) error { + return encode(&u.SessionID, &u.Background, &u.Foreground, &u.Search, &u.BorderRadius) +} + +func ScanSettingsPkLast(u *Settings, encode advsql.ScanFunc) error { + return encode(&u.Background, &u.Foreground, &u.Search, &u.BorderRadius, &u.SessionID) +} diff --git a/go.mod b/go.mod index 06bfc43..5bb93c3 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,11 @@ module git.milar.in/milarin/startpage go 1.18 require ( + git.milar.in/milarin/advsql v0.0.16 + git.milar.in/milarin/channel v0.0.10 git.milar.in/milarin/envvars/v2 v2.0.0 - git.milar.in/milarin/slices v0.0.4 github.com/gorilla/mux v1.8.0 - github.com/gorilla/sessions v1.2.1 - github.com/srinathgs/mysqlstore v0.0.0-20200417050510-9cbb9420fc4c golang.org/x/text v0.3.7 ) -require ( - github.com/go-sql-driver/mysql v1.6.0 // indirect - github.com/gorilla/securecookie v1.1.1 // indirect -) +require github.com/go-sql-driver/mysql v1.6.0 // indirect diff --git a/go.sum b/go.sum index 06c5ffe..6a4db12 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,12 @@ +git.milar.in/milarin/advsql v0.0.16 h1:vPGPM4Vyb8aR8qcj479JXpC+MlUjeFsPZ6CUIoKkC+I= +git.milar.in/milarin/advsql v0.0.16/go.mod h1:bdQrtqZaGkxmbYRi2OhF5T8DDDztnvbE41faIVSZPLQ= +git.milar.in/milarin/channel v0.0.10 h1:oHCx69lMF/KIot38OgfVtZ9oBmv5FadZp2MMDPPnC3o= +git.milar.in/milarin/channel v0.0.10/go.mod h1:We83LTI8S7u7II3pD+A2ChCDWJfCkcBUCUqii9HjTtM= git.milar.in/milarin/envvars/v2 v2.0.0 h1:DWRQCWaHqzDD8NGpSgv5tYLuF9A/dVFPAtTvz3oiIqE= git.milar.in/milarin/envvars/v2 v2.0.0/go.mod h1:HkdEi+gG2lJSmVq547bTlQV4qQ0hO333bE8IrE0B9yY= -git.milar.in/milarin/slices v0.0.4 h1:z92jgsKcnLPLfgXkTVCzH2XXesfXzhe0Osx+PkfCHVI= -git.milar.in/milarin/slices v0.0.4/go.mod h1:NOr53AOeur/qscu/FBj3lsFR262PNYBccLYSTCAXRk4= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/srinathgs/mysqlstore v0.0.0-20200417050510-9cbb9420fc4c h1:HT6QRF79dL2Ed6HCrX9RufkxFGo7+NPkgYF1Uzvv/js= -github.com/srinathgs/mysqlstore v0.0.0-20200417050510-9cbb9420fc4c/go.mod h1:kt46Hd+lF0rtpeRgOvYSWYJItOAd73EKkIBZFbX7TXs= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= diff --git a/main.go b/main.go index 41cf12e..ef5d106 100644 --- a/main.go +++ b/main.go @@ -2,31 +2,20 @@ package main import ( "embed" - "encoding/gob" - "encoding/hex" "encoding/json" "errors" - "flag" "fmt" "html/template" - "io" "net/http" "net/url" - "os" "strings" - "time" _ "embed" + "git.milar.in/milarin/advsql" + "git.milar.in/milarin/channel" "git.milar.in/milarin/envvars/v2" - "git.milar.in/milarin/slices" "github.com/gorilla/mux" - "github.com/srinathgs/mysqlstore" -) - -var ( - intf = flag.String("i", "", "interface") - port = flag.Uint("p", 80, "port") ) var ( @@ -37,28 +26,18 @@ var ( StaticFS embed.FS Templates *template.Template - - SessionStore = Must(mysqlstore.NewMySQLStore( - fmt.Sprintf( - "%s:%s@tcp(%s:%d)/%s?parseTime=true&loc=Local", - envvars.String("DB_USER", ""), - envvars.String("DB_PASS", ""), - envvars.String("DB_HOST", ""), - envvars.Uint16("DB_PORT", 3306), - envvars.String("DB_BASE", "startpage"), - ), - "sessions", - "/", - int((time.Hour * 24 * 365).Seconds()), - Must(hex.DecodeString(os.Getenv("SESSION_KEY"))), - )) ) func main() { - gob.Register([]Bookmark{}) - gob.Register(&Settings{}) - - flag.Parse() + advsql.InitMysqlDatabase( + Database, + envvars.String("DB_HOST", ""), + envvars.Uint16("DB_PORT", 3306), + envvars.String("DB_USER", ""), + envvars.String("DB_PASS", ""), + envvars.String("DB_BASE", "startpage"), + ) + defer Database.Close() if tmpl, err := template.New("homepage").ParseFS(TemplateFS, "templates/*"); err == nil { Templates = tmpl @@ -73,18 +52,23 @@ func main() { r.HandleFunc("/search", search) r.PathPrefix("/static/").Handler(http.FileServer(http.FS(StaticFS))) - if err := http.ListenAndServe(fmt.Sprintf("%s:%d", *intf, *port), r); err != nil { + if err := http.ListenAndServe(fmt.Sprintf("%s:%d", envvars.String("HTTP_INTF", ""), envvars.Uint16("HTTP_PORT", 80)), r); err != nil { panic(err) } } func handler(w http.ResponseWriter, r *http.Request) { - session, _ := SessionStore.Get(r, "settings") + session, err := GetSession(w, r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Println(err) + return + } data := &TmplData{ text: GetText(r), - Bookmarks: GetValueDefault(session, "bookmarks", DefaultBookmarks()), - Settings: GetValueDefault(session, "settings", DefaultSettings()), + Bookmarks: channel.ToSliceDeref(GetBookmarksBySessionIdOrdered(session.ID)), + Settings: GetSettingsBySessionID(session.ID), } if err := Templates.ExecuteTemplate(w, "index.html", data); err != nil { @@ -92,23 +76,14 @@ func handler(w http.ResponseWriter, r *http.Request) { } } -func ProvideFile(path, contentType string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - file, err := StaticFS.Open(path) - if err != nil { - fmt.Println(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer file.Close() - w.Header().Add("Content-Type", contentType) - io.Copy(w, file) - } -} - func search(w http.ResponseWriter, r *http.Request) { - session, _ := SessionStore.Get(r, "settings") - settings := GetValueDefault(session, "settings", DefaultSettings()) + session, err := GetSession(w, r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Println(err) + return + } + settings := GetSettingsBySessionID(session.ID) if err := r.ParseForm(); err != nil { fmt.Println(err) @@ -131,21 +106,24 @@ func search(w http.ResponseWriter, r *http.Request) { } func customize(w http.ResponseWriter, r *http.Request) { - session, _ := SessionStore.Get(r, "settings") + session, err := GetSession(w, r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Println(err) + return + } text := GetText(r) - bookmarks := slices.Map(GetValueDefault(session, "bookmarks", DefaultBookmarks()), func(b Bookmark) BookmarkData { - return BookmarkData{ - Bookmark: &b, - text: text, - } - }) - data := &CustomizeData{ - text: text, - Bookmarks: bookmarks, - Settings: GetValueDefault(session, "settings", DefaultSettings()), + text: text, + Settings: GetSettingsBySessionID(session.ID), + Bookmarks: channel.ToSliceDeref(channel.MapSuccessive(GetBookmarksBySessionIdOrdered(session.ID), func(b *Bookmark) *BookmarkData { + return &BookmarkData{ + Bookmark: b, + text: text, + } + })), } if err := Templates.ExecuteTemplate(w, "customize.html", data); err != nil { @@ -159,21 +137,38 @@ func saveChanges(w http.ResponseWriter, r *http.Request) { return } - session, _ := SessionStore.Get(r, "settings") - - sessionData := &SessionData{} - err := json.NewDecoder(r.Body).Decode(&sessionData) + session, err := GetSession(w, r) if err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Println(err) return } - Reorder(sessionData.Bookmarks) - session.Values["bookmarks"] = sessionData.Bookmarks - session.Values["settings"] = sessionData.Settings + sessionData := &SessionData{} + err = json.NewDecoder(r.Body).Decode(&sessionData) + if err != nil { + fmt.Println(err) + w.WriteHeader(http.StatusInternalServerError) + return + } - if err := session.Save(r, w); err != nil { + Reorder(sessionData.Bookmarks) + if err := DeleteBookmarksBySessionID(session.ID); err != nil { + fmt.Println(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + for _, bookmark := range sessionData.Bookmarks { + bookmark.SessionID = session.ID + if err := InsertBookmark(&bookmark); err != nil { + fmt.Println(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + } + + sessionData.Settings.SessionID = session.ID + if err := UpdateSettings(sessionData.Settings); err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) return diff --git a/settings.go b/settings.go index 62fb63e..230184e 100644 --- a/settings.go +++ b/settings.go @@ -1,6 +1,8 @@ package main type Settings struct { + SessionID string `json:"-"` + Background string `json:"background_color"` Foreground string `json:"foreground_color"` Search string `json:"search_query"` diff --git a/templates/index.html b/templates/index.html index 861c756..b37d79e 100644 --- a/templates/index.html +++ b/templates/index.html @@ -16,6 +16,12 @@ background-color: {{ .Settings.Background }}; color: {{ .Settings.Foreground }}; } + + {{ if not .Settings.Search }} + #search { + display: none; + } + {{ end }} diff --git a/utils.go b/utils.go index 14cea34..999b2ac 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,10 @@ package main import ( - "github.com/gorilla/sessions" + crand "crypto/rand" + "encoding/hex" + "net/http" + "time" ) func Must[T any](value T, err error) T { @@ -11,25 +14,73 @@ func Must[T any](value T, err error) T { return value } -func GetValue[K any, T any](session *sessions.Session, key K) (T, bool) { - value, ok := session.Values[key] - if !ok && value == nil { - return *new(T), false +func generateRandomID(length int) (string, error) { + data := make([]byte, 32) + if _, err := crand.Read(data); err != nil { + return "", err } - - castedValue, ok := value.(T) - if !ok { - return *new(T), false - } - - return castedValue, true + return hex.EncodeToString(data), nil } -func GetValueDefault[K any, T any](session *sessions.Session, key K, defaultValue T) T { - v, ok := GetValue[K, T](session, key) - if !ok { - return defaultValue +func GetSession(w http.ResponseWriter, r *http.Request) (*Session, error) { + cookie, err := r.Cookie("session") + if err != nil { + return makeNewSession(w) } - return v + session := GetSessionByID(cookie.Value) + if session == nil { + return makeNewSession(w) + } + + // update cookie expiration date + session.ExpirationDate = time.Now().AddDate(1, 0, 0) + if err := UpdateSession(session); err != nil { + return nil, err + } + http.SetCookie(w, &http.Cookie{ + Name: "session", + Value: session.ID, + Expires: session.ExpirationDate, + Secure: true, + HttpOnly: true, + }) + + return session, nil +} + +func makeNewSession(w http.ResponseWriter) (*Session, error) { + session, err := NewSession() + if err != nil { + return nil, err + } + + if err := InsertSession(session); err != nil { + return nil, err + } + + // insert default bookmarks + for _, bookmark := range DefaultBookmarks() { + bookmark.SessionID = session.ID + if err := InsertBookmark(&bookmark); err != nil { + return nil, err + } + } + + // insert default settings + settings := DefaultSettings() + settings.SessionID = session.ID + if err := InsertSettings(settings); err != nil { + return nil, err + } + + http.SetCookie(w, &http.Cookie{ + Name: "session", + Value: session.ID, + Expires: session.ExpirationDate, + Secure: true, + HttpOnly: true, + }) + + return session, nil }