package main import ( "embed" "encoding/json" "errors" "fmt" "html/template" "net/http" "net/url" "strings" "time" "git.milar.in/milarin/advsql" "git.milar.in/milarin/channel" "git.milar.in/milarin/envvars/v2" "github.com/gorilla/mux" ) var ( //go:embed templates/* TemplateFS embed.FS //go:embed static/* StaticFS embed.FS Templates *template.Template ) func main() { advsql.InitMysqlDatabase( Database, envvars.String("DB_HOST", ""), envvars.Uint16("DB_PORT", 3306), envvars.String("DB_USER", ""), envvars.String("DB_PASS", ""), envvars.String("DB_BASE", "startpage"), ) defer Database.Close() if tmpl, err := template.New("homepage").Funcs(tmplFuncs).ParseFS(TemplateFS, "templates/*"); err == nil { Templates = tmpl } else { panic(err) } r := mux.NewRouter() r.HandleFunc("/", handler) r.HandleFunc("/customize", customize) r.HandleFunc("/save-changes", saveChanges) r.HandleFunc("/search", search) r.PathPrefix("/static/").Handler(http.FileServer(http.FS(StaticFS))) go func() { for range time.Tick(time.Hour) { if err := DeleteSessionsByExpirationDateBefore(time.Now()); err != nil { fmt.Println("could not delete expired sessions", err) } } }() if err := http.ListenAndServe(fmt.Sprintf("%s:%d", envvars.String("HTTP_INTF", ""), envvars.Uint16("HTTP_PORT", 80)), r); err != nil { panic(err) } } func handler(w http.ResponseWriter, r *http.Request) { session, err := GetSession(w, r) if err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Println(err) return } data := &TmplData{ text: GetText(r), Bookmarks: channel.ToSliceDeref(GetBookmarksBySessionIdOrdered(session.ID)), Settings: GetSettingsBySessionID(session.ID), } if err := Templates.ExecuteTemplate(w, "index.html", data); err != nil { panic(err) } } func search(w http.ResponseWriter, r *http.Request) { session, err := GetSession(w, r) if err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Println(err) return } settings := GetSettingsBySessionID(session.ID) if err := r.ParseForm(); err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) return } query := r.Form.Get("query") if uri, err := ParseURI(query); err == nil { // url w.Header().Add("Location", uri.String()) w.WriteHeader(http.StatusMovedPermanently) } else { // search string w.Header().Add("Location", fmt.Sprintf(settings.Search, query)) w.WriteHeader(http.StatusMovedPermanently) } } func customize(w http.ResponseWriter, r *http.Request) { session, err := GetSession(w, r) if err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Println(err) return } text := GetText(r) data := &CustomizeData{ text: text, Settings: GetSettingsBySessionID(session.ID), Bookmarks: channel.ToSliceDeref(channel.MapSuccessive(GetBookmarksBySessionIdOrdered(session.ID), func(b *Bookmark) *BookmarkData { return &BookmarkData{ Bookmark: b, text: text, } })), } if err := Templates.ExecuteTemplate(w, "customize.html", data); err != nil { panic(err) } } func saveChanges(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { w.WriteHeader(http.StatusMethodNotAllowed) return } session, err := GetSession(w, r) if err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Println(err) return } sessionData := &SessionData{} err = json.NewDecoder(r.Body).Decode(&sessionData) if err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) return } Reorder(sessionData.Bookmarks) if err := DeleteBookmarksBySessionID(session.ID); err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) return } for _, bookmark := range sessionData.Bookmarks { bookmark.SessionID = session.ID if err := InsertBookmark(&bookmark); err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) return } } sessionData.Settings.SessionID = session.ID if err := UpdateSettings(sessionData.Settings); err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) return } } func ParseURI(uri string) (*url.URL, error) { if !strings.HasPrefix(uri, "http://") && !strings.HasPrefix(uri, "https://") { uri = "https://" + uri } ret, err := url.ParseRequestURI(uri) if err != nil { return nil, err } splits := strings.Split(ret.Hostname(), ".") if len(splits) <= 1 { return nil, errors.New("hostname doesn't have a TLD") } if _, ok := TLDs[splits[len(splits)-1]]; !ok { return nil, errors.New("invalid top level domain") } return ret, nil }