set: Improve coverage and cleanup. Switch sshchat package to use it.
authorAndrey Petrov <andrey.petrov@shazow.net>
Mon, 15 Aug 2016 01:03:16 +0000 (21:03 -0400)
committerAndrey Petrov <andrey.petrov@shazow.net>
Wed, 24 Aug 2016 17:54:20 +0000 (13:54 -0400)
auth.go
set.go [deleted file]
set/item.go
set/set.go
set/set_test.go
set_test.go [deleted file]

diff --git a/auth.go b/auth.go
index 7095260082a8a0c8a92172628934b55f1b9e3861..ced008b8663822b89b619a31fbfb88065c05fc5d 100644 (file)
--- a/auth.go
+++ b/auth.go
@@ -5,6 +5,7 @@ import (
        "net"
        "time"
 
+       "github.com/shazow/ssh-chat/set"
        "github.com/shazow/ssh-chat/sshd"
        "golang.org/x/crypto/ssh"
 )
@@ -20,10 +21,14 @@ func newAuthKey(key ssh.PublicKey) string {
        if key == nil {
                return ""
        }
-       // FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
+       // FIXME: Is there a better way to index pubkeys without marshal'ing them into strings?
        return sshd.Fingerprint(key)
 }
 
+func newAuthItem(key ssh.PublicKey) set.Item {
+       return set.StringItem(newAuthKey(key))
+}
+
 // newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup.
 func newAuthAddr(addr net.Addr) string {
        if addr == nil {
@@ -35,19 +40,19 @@ func newAuthAddr(addr net.Addr) string {
 
 // Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
 type Auth struct {
-       bannedAddr *Set
-       banned     *Set
-       whitelist  *Set
-       ops        *Set
+       bannedAddr *set.Set
+       banned     *set.Set
+       whitelist  *set.Set
+       ops        *set.Set
 }
 
 // NewAuth creates a new empty Auth.
 func NewAuth() *Auth {
        return &Auth{
-               bannedAddr: NewSet(),
-               banned:     NewSet(),
-               whitelist:  NewSet(),
-               ops:        NewSet(),
+               bannedAddr: set.New(),
+               banned:     set.New(),
+               whitelist:  set.New(),
+               ops:        set.New(),
        }
 }
 
@@ -85,13 +90,13 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
        if key == nil {
                return
        }
-       authkey := newAuthKey(key)
+       authItem := newAuthItem(key)
        if d != 0 {
-               a.ops.AddExpiring(authkey, d)
+               a.ops.Add(set.Expire(authItem, d))
        } else {
-               a.ops.Add(authkey)
+               a.ops.Add(authItem)
        }
-       logger.Debugf("Added to ops: %s (for %s)", authkey, d)
+       logger.Debugf("Added to ops: %s (for %s)", authItem.Key(), d)
 }
 
 // IsOp checks if a public key is an op.
@@ -108,13 +113,13 @@ func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
        if key == nil {
                return
        }
-       authkey := newAuthKey(key)
+       authItem := newAuthItem(key)
        if d != 0 {
-               a.whitelist.AddExpiring(authkey, d)
+               a.whitelist.Add(set.Expire(authItem, d))
        } else {
-               a.whitelist.Add(authkey)
+               a.whitelist.Add(authItem)
        }
-       logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
+       logger.Debugf("Added to whitelist: %s (for %s)", authItem.Key(), d)
 }
 
 // Ban will set a public key as banned.
@@ -127,21 +132,22 @@ func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
 
 // BanFingerprint will set a public key fingerprint as banned.
 func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
+       authItem := set.StringItem(authkey)
        if d != 0 {
-               a.banned.AddExpiring(authkey, d)
+               a.banned.Add(set.Expire(authItem, d))
        } else {
-               a.banned.Add(authkey)
+               a.banned.Add(authItem)
        }
-       logger.Debugf("Added to banned: %s (for %s)", authkey, d)
+       logger.Debugf("Added to banned: %s (for %s)", authItem.Key(), d)
 }
 
 // Ban will set an IP address as banned.
 func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
-       key := newAuthAddr(addr)
+       authItem := set.StringItem(addr.String())
        if d != 0 {
-               a.bannedAddr.AddExpiring(key, d)
+               a.bannedAddr.Add(set.Expire(authItem, d))
        } else {
-               a.bannedAddr.Add(key)
+               a.bannedAddr.Add(authItem)
        }
-       logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
+       logger.Debugf("Added to bannedAddr: %s (for %s)", authItem.Key(), d)
 }
diff --git a/set.go b/set.go
deleted file mode 100644 (file)
index 3e29a57..0000000
--- a/set.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package sshchat
-
-import (
-       "sync"
-       "time"
-)
-
-type expiringValue struct {
-       time.Time
-}
-
-func (v expiringValue) Bool() bool {
-       return time.Now().Before(v.Time)
-}
-
-type value struct{}
-
-func (v value) Bool() bool {
-       return true
-}
-
-type setValue interface {
-       Bool() bool
-}
-
-// Set with expire-able keys
-type Set struct {
-       sync.Mutex
-       lookup map[string]setValue
-}
-
-// NewSet creates a new set.
-func NewSet() *Set {
-       return &Set{
-               lookup: map[string]setValue{},
-       }
-}
-
-// Len returns the size of the set right now.
-func (s *Set) Len() int {
-       s.Lock()
-       defer s.Unlock()
-       return len(s.lookup)
-}
-
-// In checks if an item exists in this set.
-func (s *Set) In(key string) bool {
-       s.Lock()
-       v, ok := s.lookup[key]
-       if ok && !v.Bool() {
-               ok = false
-               delete(s.lookup, key)
-       }
-       s.Unlock()
-       return ok
-}
-
-// Add item to this set, replace if it exists.
-func (s *Set) Add(key string) {
-       s.Lock()
-       s.lookup[key] = value{}
-       s.Unlock()
-}
-
-// Add item to this set, replace if it exists.
-func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
-       until := time.Now().Add(d)
-       s.Lock()
-       s.lookup[key] = expiringValue{until}
-       s.Unlock()
-       return until
-}
index bcb2a1be79de6317f8ceae8829094981955f4c9a..a59fa131276ed388f82682b5b3e2959fe12ccf2b 100644 (file)
@@ -15,7 +15,7 @@ func (item StringItem) Key() string {
 }
 
 func (item StringItem) Value() interface{} {
-       return string(item)
+       return true
 }
 
 func Expire(item Item, d time.Duration) Item {
index 06ccec30b03600d8ab86abf7ea7e656e6622f7b1..1b489ad5019543155a5a6c0c9f54918b1e1b146d 100644 (file)
@@ -15,6 +15,8 @@ var ErrMissing = errors.New("item does not exist")
 // Returned when a nil item is added. Nil values are considered expired and invalid.
 var ErrNil = errors.New("item value must not be nil")
 
+type IterFunc func(key string, item Item) error
+
 type Set struct {
        sync.RWMutex
        lookup    map[string]Item
@@ -153,24 +155,20 @@ func (s *Set) Replace(oldKey string, item Item) error {
 
 // Each loops over every item while holding a read lock and applies fn to each
 // element.
-func (s *Set) Each(fn func(item Item)) {
-       cleanup := []string{}
+func (s *Set) Each(fn IterFunc) error {
        s.RLock()
+       defer s.RUnlock()
        for key, item := range s.lookup {
                if item.Value() == nil {
-                       cleanup = append(cleanup, key)
+                       defer s.cleanup(key)
                        continue
                }
-               fn(item)
-       }
-       s.RUnlock()
-
-       if len(cleanup) == 0 {
-               return
-       }
-       for _, key := range cleanup {
-               s.cleanup(key)
+               if err := fn(key, item); err != nil {
+                       // Abort early
+                       return err
+               }
        }
+       return nil
 }
 
 // ListPrefix returns a list of items with a prefix, normalized.
@@ -179,8 +177,11 @@ func (s *Set) ListPrefix(prefix string) []Item {
        r := []Item{}
        prefix = s.normalize(prefix)
 
-       s.Each(func(item Item) {
-               r = append(r, item)
+       s.Each(func(key string, item Item) error {
+               if strings.HasPrefix(key, prefix) {
+                       r = append(r, item)
+               }
+               return nil
        })
 
        return r
index 7b55dc8fcbd1af9dd572ebb0070e60e3c046b52c..f75192d87b589431d09d14e444246c3ca337000e 100644 (file)
@@ -26,14 +26,14 @@ func TestSetExpiring(t *testing.T) {
                t.Errorf("ExpiringItem a nanosec ago is not expiring")
        }
 
-       item = &ExpiringItem{nil, time.Now().Add(time.Minute * 2)}
+       item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)}
        if item.Expired() {
                t.Errorf("ExpiringItem in 2 minutes is expiring now")
        }
 
-       item = Expire(StringItem("bar"), time.Minute*2).(*ExpiringItem)
+       item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem)
        until := item.Time
-       if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
+       if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) {
                t.Errorf("until is not a minute after %s: %s", time.Now(), until)
        }
        if item.Value() == nil {
@@ -54,11 +54,38 @@ func TestSetExpiring(t *testing.T) {
        if s.Len() != 2 {
                t.Error("not len 2 after set")
        }
+       if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil {
+               t.Fatalf("failed to add quux: %s", err)
+       }
 
-       if err := s.Replace("bar", Expire(StringItem("bar"), time.Minute*5)); err != nil {
+       if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil {
                t.Fatalf("failed to add bar: %s", err)
        }
-       if !s.In("bar") {
-               t.Error("failed to match before expiry")
+       if s.In("quux") {
+               t.Error("quux in set after replace")
+       }
+       if _, err := s.Get("bar"); err != nil {
+               t.Errorf("failed to get before expiry: %s", err)
+       }
+       if err := s.Add(StringItem("barbar")); err != nil {
+               t.Fatalf("failed to add barbar")
+       }
+       if _, err := s.Get("barbar"); err != nil {
+               t.Errorf("failed to get barbar: %s", err)
+       }
+       b := s.ListPrefix("b")
+       if len(b) != 2 || b[0].Key() != "bar" || b[1].Key() != "barbar" {
+               t.Errorf("b-prefix incorrect: %q", b)
+       }
+
+       if err := s.Remove("bar"); err != nil {
+               t.Fatalf("failed to remove: %s", err)
+       }
+       if s.Len() != 2 {
+               t.Error("not len 2 after remove")
+       }
+       s.Clear()
+       if s.Len() != 0 {
+               t.Error("not len 0 after clear")
        }
 }
diff --git a/set_test.go b/set_test.go
deleted file mode 100644 (file)
index 1d7fbef..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-package sshchat
-
-import (
-       "testing"
-       "time"
-)
-
-func TestSetExpiring(t *testing.T) {
-       s := NewSet()
-       if s.In("foo") {
-               t.Error("Matched before set.")
-       }
-
-       s.Add("foo")
-       if !s.In("foo") {
-               t.Errorf("Not matched after set")
-       }
-       if s.Len() != 1 {
-               t.Error("Not len 1 after set")
-       }
-
-       v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
-       if v.Bool() {
-               t.Errorf("expiringValue now is not expiring")
-       }
-
-       v = expiringValue{time.Now().Add(time.Minute * 2)}
-       if !v.Bool() {
-               t.Errorf("expiringValue in 2 minutes is expiring now")
-       }
-
-       until := s.AddExpiring("bar", time.Minute*2)
-       if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
-               t.Errorf("until is not a minute after %s: %s", time.Now(), until)
-       }
-       val, ok := s.lookup["bar"]
-       if !ok {
-               t.Errorf("bar not in lookup")
-       }
-       if !until.Equal(val.(expiringValue).Time) {
-               t.Errorf("bar's until is not equal to the expected value")
-       }
-       if !val.Bool() {
-               t.Errorf("bar expired immediately")
-       }
-
-       if !s.In("bar") {
-               t.Errorf("Not matched after timed set")
-       }
-       if s.Len() != 2 {
-               t.Error("Not len 2 after set")
-       }
-
-       s.AddExpiring("bar", time.Nanosecond*1)
-       if s.In("bar") {
-               t.Error("Matched after expired timer")
-       }
-}