Skip to content

Commit

Permalink
Merge branch 'topic/refactoring'
Browse files Browse the repository at this point in the history
  • Loading branch information
equinox0815 committed Nov 19, 2023
2 parents c507512 + e442788 commit baea797
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 134 deletions.
36 changes: 15 additions & 21 deletions cmd/whawty-nginx-sso/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,19 @@ type HandlerContext struct {
auth auth.Backend
}

func (h *HandlerContext) verifyCookie(c *gin.Context) (string, *cookie.Session, error) {
func (h *HandlerContext) verifyCookie(c *gin.Context) (*cookie.Session, error) {
cookie, err := c.Cookie(h.cookies.Options().Name)
if err != nil {
return "", nil, err
return nil, err
}
if cookie == "" {
return "", nil, errors.New("no cookie found")
return nil, errors.New("no cookie found")
}
id, session, err := h.cookies.Verify(cookie)
session, err := h.cookies.Verify(cookie)
if err != nil {
return "", nil, err
return nil, err
}
return id, &session, nil
return &session, nil
}

func (h *HandlerContext) getBasePath(c *gin.Context) string {
Expand All @@ -83,7 +83,7 @@ func (h *HandlerContext) getBasePath(c *gin.Context) string {
}

func (h *HandlerContext) handleAuth(c *gin.Context) {
_, session, err := h.verifyCookie(c)
session, err := h.verifyCookie(c)
if err != nil {
c.Data(http.StatusUnauthorized, "text/plain", []byte(err.Error()))
return
Expand All @@ -96,13 +96,12 @@ func (h *HandlerContext) handleLoginGet(c *gin.Context) {
login := h.conf.Login
login.BasePath = h.getBasePath(c)

_, session, err := h.verifyCookie(c)
session, err := h.verifyCookie(c)
if err == nil {
// TODO: follow redir?
c.HTML(http.StatusOK, "logged-in.htmpl", pongo2.Context{
"login": login,
"username": session.Username,
"expires": time.Unix(session.Expires, 0),
"login": login,
"session": session,
})
return
}
Expand Down Expand Up @@ -141,7 +140,7 @@ func (h *HandlerContext) handleLoginPost(c *gin.Context) {
return
}

value, opts, err := h.cookies.New(cookie.Session{Username: username})
value, opts, err := h.cookies.New(username)
if err != nil {
c.HTML(http.StatusBadRequest, "login.htmpl", pongo2.Context{
"login": login,
Expand All @@ -153,20 +152,15 @@ func (h *HandlerContext) handleLoginPost(c *gin.Context) {
c.SetCookie(opts.Name, value, opts.MaxAge, "/", opts.Domain, opts.Secure, true)

if redirect == "" {
c.HTML(http.StatusOK, "logged-in.htmpl", pongo2.Context{
"login": login,
"username": username,
"expires": time.Now().Add(time.Duration(opts.MaxAge) * time.Second),
})
return
redirect = path.Join(h.getBasePath(c), "login")
}
c.Redirect(http.StatusSeeOther, redirect)
}

func (h *HandlerContext) handleLogout(c *gin.Context) {
id, session, err := h.verifyCookie(c)
session, err := h.verifyCookie(c)
if err == nil {
if err = h.cookies.Revoke(id, *session); err != nil {
if err = h.cookies.Revoke(*session); err != nil {
// TODO: render error page!
c.JSON(http.StatusInternalServerError, WebError{err.Error()})
return
Expand All @@ -182,7 +176,7 @@ func (h *HandlerContext) handleLogout(c *gin.Context) {
}

func (h *HandlerContext) handleSessions(c *gin.Context) {
_, session, err := h.verifyCookie(c)
session, err := h.verifyCookie(c)
if err != nil {
c.JSON(http.StatusUnauthorized, WebError{err.Error()})
return
Expand Down
40 changes: 20 additions & 20 deletions cookie/backend_in-memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,38 +40,38 @@ import (
type InMemoryBackendConfig struct {
}

type InMemorySessionList map[ulid.ULID]Session
type InMemorySessionMap map[ulid.ULID]SessionBase

type InMemoryBackend struct {
mutex sync.RWMutex
sessions map[string]InMemorySessionList
revoked InMemorySessionList
sessions map[string]InMemorySessionMap
revoked InMemorySessionMap
}

func NewInMemoryBackend(conf *InMemoryBackendConfig) (*InMemoryBackend, error) {
m := &InMemoryBackend{}
m.sessions = make(map[string]InMemorySessionList)
m.revoked = make(InMemorySessionList)
m.sessions = make(map[string]InMemorySessionMap)
m.revoked = make(InMemorySessionMap)
return m, nil
}

func (b *InMemoryBackend) Save(id ulid.ULID, session Session) error {
func (b *InMemoryBackend) Save(session Session) error {
b.mutex.Lock()
defer b.mutex.Unlock()

sessions, exists := b.sessions[session.Username]
if !exists {
sessions = make(InMemorySessionList)
sessions = make(InMemorySessionMap)
b.sessions[session.Username] = sessions
}
if _, exists = sessions[id]; exists {
return fmt.Errorf("session '%v' already exists!", id)
if _, exists = sessions[session.ID]; exists {
return fmt.Errorf("session '%v' already exists!", session.ID)
}
sessions[id] = session
sessions[session.ID] = session.SessionBase
return nil
}

func (b *InMemoryBackend) ListUser(username string) (list StoredSessionList, err error) {
func (b *InMemoryBackend) ListUser(username string) (list SessionList, err error) {
b.mutex.RLock()
defer b.mutex.RUnlock()

Expand All @@ -81,46 +81,46 @@ func (b *InMemoryBackend) ListUser(username string) (list StoredSessionList, err
}
for id, session := range sessions {
if _, revoked := b.revoked[id]; !revoked {
list = append(list, StoredSession{ID: id, Session: session})
list = append(list, Session{ID: id, SessionBase: session})
}
}
return
}

func (b *InMemoryBackend) Revoke(id ulid.ULID, session Session) error {
func (b *InMemoryBackend) Revoke(session Session) error {
b.mutex.Lock()
defer b.mutex.Unlock()

b.revoked[id] = session
b.revoked[session.ID] = session.SessionBase
return nil
}

func (b *InMemoryBackend) IsRevoked(id ulid.ULID) (bool, error) {
func (b *InMemoryBackend) IsRevoked(session Session) (bool, error) {
b.mutex.RLock()
defer b.mutex.RUnlock()

_, exists := b.revoked[id]
_, exists := b.revoked[session.ID]
return exists, nil
}

func (b *InMemoryBackend) ListRevoked() (list StoredSessionList, err error) {
func (b *InMemoryBackend) ListRevoked() (list SessionList, err error) {
b.mutex.RLock()
defer b.mutex.RUnlock()

for id, session := range b.revoked {
list = append(list, StoredSession{ID: id, Session: session})
list = append(list, Session{ID: id, SessionBase: session})
}
return
}

func (b *InMemoryBackend) LoadRevocations(list StoredSessionList) (cnt uint, err error) {
func (b *InMemoryBackend) LoadRevocations(list SessionList) (cnt uint, err error) {
b.mutex.Lock()
defer b.mutex.Unlock()

cnt = 0
for _, session := range list {
if _, exists := b.revoked[session.ID]; !exists {
b.revoked[session.ID] = session.Session
b.revoked[session.ID] = session.SessionBase
cnt = cnt + 1
}
}
Expand Down
67 changes: 26 additions & 41 deletions cookie/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,13 @@ type SignerVerifier interface {
Verify(payload, signature []byte) error
}

type StoredSession struct {
ID ulid.ULID `json:"id"`
Session Session `josn:"session"`
}

type StoredSessionList []StoredSession
type SessionList []Session

func (l StoredSessionList) MarshalJSON() ([]byte, error) {
func (l SessionList) MarshalJSON() ([]byte, error) {
if len(l) == 0 {
return []byte("[]"), nil
}
var tmp []StoredSession = l
var tmp []Session = l
return json.Marshal(tmp)
}

Expand All @@ -104,12 +99,12 @@ type SignedRevocationList struct {
}

type StoreBackend interface {
Save(id ulid.ULID, session Session) error
ListUser(username string) (StoredSessionList, error)
Revoke(id ulid.ULID, session Session) error
IsRevoked(id ulid.ULID) (bool, error)
ListRevoked() (StoredSessionList, error)
LoadRevocations(StoredSessionList) (uint, error)
Save(session Session) error
ListUser(username string) (SessionList, error)
Revoke(session Session) error
IsRevoked(session Session) (bool, error)
ListRevoked() (SessionList, error)
LoadRevocations(SessionList) (uint, error)
CollectGarbage() (uint, error)
}

Expand Down Expand Up @@ -242,7 +237,7 @@ func (st *Store) syncRevocations(client *http.Client, syncBaseURL *url.URL, toke
return
}

var list StoredSessionList
var list SessionList
if err = json.Unmarshal(signed.Revoked, &list); err != nil {
st.infoLog.Printf("sync-store: error parsing sync response: %v", err)
return
Expand Down Expand Up @@ -332,12 +327,13 @@ func (st *Store) Options() (opts Options) {
return
}

func (st *Store) New(s Session) (value string, opts Options, err error) {
func (st *Store) New(username string) (value string, opts Options, err error) {
if st.signer == nil {
err = fmt.Errorf("no signing key loaded")
return
}

s := SessionBase{Username: username}
s.SetExpiry(st.conf.Expire)
id := ulid.Make()
var v *Value
Expand All @@ -348,7 +344,7 @@ func (st *Store) New(s Session) (value string, opts Options, err error) {
return
}

if err = st.backend.Save(id, s); err != nil {
if err = st.backend.Save(Session{ID: id, SessionBase: s}); err != nil {
return
}
st.dbgLog.Printf("successfully generated new session('%v'): %+v", id, s)
Expand All @@ -358,7 +354,7 @@ func (st *Store) New(s Session) (value string, opts Options, err error) {
return
}

func (st *Store) Verify(value string) (id string, s Session, err error) {
func (st *Store) Verify(value string) (s Session, err error) {
var v Value
if err = v.FromString(value); err != nil {
return
Expand All @@ -374,15 +370,17 @@ func (st *Store) Verify(value string) (id string, s Session, err error) {
return
}

var _id ulid.ULID
if _id, err = v.ID(); err != nil {
if s, err = v.Session(); err != nil {
err = fmt.Errorf("unable to decode cookie: %v", err)
return
}
id = _id.String()
if s.IsExpired() {
err = fmt.Errorf("cookie is expired")
return
}

var revoked bool
if revoked, err = st.backend.IsRevoked(_id); err != nil {
if revoked, err = st.backend.IsRevoked(s); err != nil {
err = fmt.Errorf("failed to check for cookie revocation: %v", err)
return
}
Expand All @@ -391,37 +389,24 @@ func (st *Store) Verify(value string) (id string, s Session, err error) {
return
}

if s, err = v.Session(); err != nil {
err = fmt.Errorf("unable to decode cookie: %v", err)
return
}
if s.IsExpired() {
err = fmt.Errorf("cookie is expired")
return
}

st.dbgLog.Printf("successfully verified session('%v'): %+v", id, s)
st.dbgLog.Printf("successfully verified session('%v'): %+v", s.ID, s.SessionBase)
return
}

func (st *Store) ListUser(username string) (StoredSessionList, error) {
func (st *Store) ListUser(username string) (SessionList, error) {
return st.backend.ListUser(username)
}

func (st *Store) Revoke(id string, session Session) error {
toRevoke, err := ulid.ParseStrict(id)
if err != nil {
return err
}
if err = st.backend.Revoke(toRevoke, session); err != nil {
func (st *Store) Revoke(session Session) error {
if err := st.backend.Revoke(session); err != nil {
return err
}
st.dbgLog.Printf("successfully revoked session('%v')", id)
st.dbgLog.Printf("successfully revoked session('%v')", session.ID)
return nil
}

func (st *Store) ListRevoked() (result SignedRevocationList, err error) {
var revoked StoredSessionList
var revoked SessionList
if revoked, err = st.backend.ListRevoked(); err != nil {
return
}
Expand Down
Loading

0 comments on commit baea797

Please sign in to comment.