"net"
"time"
+ "github.com/shazow/ssh-chat/set"
"github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh"
)
if key == nil {
return ""
}
- // FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
+ // FIXME: Is there a better way to index pubkeys without marshal'ing them into strings?
return sshd.Fingerprint(key)
}
+func newAuthItem(key ssh.PublicKey) set.Item {
+ return set.StringItem(newAuthKey(key))
+}
+
// newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup.
func newAuthAddr(addr net.Addr) string {
if addr == nil {
// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
type Auth struct {
- bannedAddr *Set
- banned *Set
- whitelist *Set
- ops *Set
+ bannedAddr *set.Set
+ banned *set.Set
+ whitelist *set.Set
+ ops *set.Set
}
// NewAuth creates a new empty Auth.
func NewAuth() *Auth {
return &Auth{
- bannedAddr: NewSet(),
- banned: NewSet(),
- whitelist: NewSet(),
- ops: NewSet(),
+ bannedAddr: set.New(),
+ banned: set.New(),
+ whitelist: set.New(),
+ ops: set.New(),
}
}
if key == nil {
return
}
- authkey := newAuthKey(key)
+ authItem := newAuthItem(key)
if d != 0 {
- a.ops.AddExpiring(authkey, d)
+ a.ops.Add(set.Expire(authItem, d))
} else {
- a.ops.Add(authkey)
+ a.ops.Add(authItem)
}
- logger.Debugf("Added to ops: %s (for %s)", authkey, d)
+ logger.Debugf("Added to ops: %s (for %s)", authItem.Key(), d)
}
// IsOp checks if a public key is an op.
if key == nil {
return
}
- authkey := newAuthKey(key)
+ authItem := newAuthItem(key)
if d != 0 {
- a.whitelist.AddExpiring(authkey, d)
+ a.whitelist.Add(set.Expire(authItem, d))
} else {
- a.whitelist.Add(authkey)
+ a.whitelist.Add(authItem)
}
- logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
+ logger.Debugf("Added to whitelist: %s (for %s)", authItem.Key(), d)
}
// Ban will set a public key as banned.
// BanFingerprint will set a public key fingerprint as banned.
func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
+ authItem := set.StringItem(authkey)
if d != 0 {
- a.banned.AddExpiring(authkey, d)
+ a.banned.Add(set.Expire(authItem, d))
} else {
- a.banned.Add(authkey)
+ a.banned.Add(authItem)
}
- logger.Debugf("Added to banned: %s (for %s)", authkey, d)
+ logger.Debugf("Added to banned: %s (for %s)", authItem.Key(), d)
}
// Ban will set an IP address as banned.
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
- key := newAuthAddr(addr)
+ authItem := set.StringItem(addr.String())
if d != 0 {
- a.bannedAddr.AddExpiring(key, d)
+ a.bannedAddr.Add(set.Expire(authItem, d))
} else {
- a.bannedAddr.Add(key)
+ a.bannedAddr.Add(authItem)
}
- logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
+ logger.Debugf("Added to bannedAddr: %s (for %s)", authItem.Key(), d)
}
+++ /dev/null
-package sshchat
-
-import (
- "sync"
- "time"
-)
-
-type expiringValue struct {
- time.Time
-}
-
-func (v expiringValue) Bool() bool {
- return time.Now().Before(v.Time)
-}
-
-type value struct{}
-
-func (v value) Bool() bool {
- return true
-}
-
-type setValue interface {
- Bool() bool
-}
-
-// Set with expire-able keys
-type Set struct {
- sync.Mutex
- lookup map[string]setValue
-}
-
-// NewSet creates a new set.
-func NewSet() *Set {
- return &Set{
- lookup: map[string]setValue{},
- }
-}
-
-// Len returns the size of the set right now.
-func (s *Set) Len() int {
- s.Lock()
- defer s.Unlock()
- return len(s.lookup)
-}
-
-// In checks if an item exists in this set.
-func (s *Set) In(key string) bool {
- s.Lock()
- v, ok := s.lookup[key]
- if ok && !v.Bool() {
- ok = false
- delete(s.lookup, key)
- }
- s.Unlock()
- return ok
-}
-
-// Add item to this set, replace if it exists.
-func (s *Set) Add(key string) {
- s.Lock()
- s.lookup[key] = value{}
- s.Unlock()
-}
-
-// Add item to this set, replace if it exists.
-func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
- until := time.Now().Add(d)
- s.Lock()
- s.lookup[key] = expiringValue{until}
- s.Unlock()
- return until
-}
}
func (item StringItem) Value() interface{} {
- return string(item)
+ return true
}
func Expire(item Item, d time.Duration) Item {
// Returned when a nil item is added. Nil values are considered expired and invalid.
var ErrNil = errors.New("item value must not be nil")
+type IterFunc func(key string, item Item) error
+
type Set struct {
sync.RWMutex
lookup map[string]Item
// Each loops over every item while holding a read lock and applies fn to each
// element.
-func (s *Set) Each(fn func(item Item)) {
- cleanup := []string{}
+func (s *Set) Each(fn IterFunc) error {
s.RLock()
+ defer s.RUnlock()
for key, item := range s.lookup {
if item.Value() == nil {
- cleanup = append(cleanup, key)
+ defer s.cleanup(key)
continue
}
- fn(item)
- }
- s.RUnlock()
-
- if len(cleanup) == 0 {
- return
- }
- for _, key := range cleanup {
- s.cleanup(key)
+ if err := fn(key, item); err != nil {
+ // Abort early
+ return err
+ }
}
+ return nil
}
// ListPrefix returns a list of items with a prefix, normalized.
r := []Item{}
prefix = s.normalize(prefix)
- s.Each(func(item Item) {
- r = append(r, item)
+ s.Each(func(key string, item Item) error {
+ if strings.HasPrefix(key, prefix) {
+ r = append(r, item)
+ }
+ return nil
})
return r
t.Errorf("ExpiringItem a nanosec ago is not expiring")
}
- item = &ExpiringItem{nil, time.Now().Add(time.Minute * 2)}
+ item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)}
if item.Expired() {
t.Errorf("ExpiringItem in 2 minutes is expiring now")
}
- item = Expire(StringItem("bar"), time.Minute*2).(*ExpiringItem)
+ item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem)
until := item.Time
- if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
+ if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
if item.Value() == nil {
if s.Len() != 2 {
t.Error("not len 2 after set")
}
+ if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil {
+ t.Fatalf("failed to add quux: %s", err)
+ }
- if err := s.Replace("bar", Expire(StringItem("bar"), time.Minute*5)); err != nil {
+ if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil {
t.Fatalf("failed to add bar: %s", err)
}
- if !s.In("bar") {
- t.Error("failed to match before expiry")
+ if s.In("quux") {
+ t.Error("quux in set after replace")
+ }
+ if _, err := s.Get("bar"); err != nil {
+ t.Errorf("failed to get before expiry: %s", err)
+ }
+ if err := s.Add(StringItem("barbar")); err != nil {
+ t.Fatalf("failed to add barbar")
+ }
+ if _, err := s.Get("barbar"); err != nil {
+ t.Errorf("failed to get barbar: %s", err)
+ }
+ b := s.ListPrefix("b")
+ if len(b) != 2 || b[0].Key() != "bar" || b[1].Key() != "barbar" {
+ t.Errorf("b-prefix incorrect: %q", b)
+ }
+
+ if err := s.Remove("bar"); err != nil {
+ t.Fatalf("failed to remove: %s", err)
+ }
+ if s.Len() != 2 {
+ t.Error("not len 2 after remove")
+ }
+ s.Clear()
+ if s.Len() != 0 {
+ t.Error("not len 0 after clear")
}
}
+++ /dev/null
-package sshchat
-
-import (
- "testing"
- "time"
-)
-
-func TestSetExpiring(t *testing.T) {
- s := NewSet()
- if s.In("foo") {
- t.Error("Matched before set.")
- }
-
- s.Add("foo")
- if !s.In("foo") {
- t.Errorf("Not matched after set")
- }
- if s.Len() != 1 {
- t.Error("Not len 1 after set")
- }
-
- v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
- if v.Bool() {
- t.Errorf("expiringValue now is not expiring")
- }
-
- v = expiringValue{time.Now().Add(time.Minute * 2)}
- if !v.Bool() {
- t.Errorf("expiringValue in 2 minutes is expiring now")
- }
-
- until := s.AddExpiring("bar", time.Minute*2)
- if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
- t.Errorf("until is not a minute after %s: %s", time.Now(), until)
- }
- val, ok := s.lookup["bar"]
- if !ok {
- t.Errorf("bar not in lookup")
- }
- if !until.Equal(val.(expiringValue).Time) {
- t.Errorf("bar's until is not equal to the expected value")
- }
- if !val.Bool() {
- t.Errorf("bar expired immediately")
- }
-
- if !s.In("bar") {
- t.Errorf("Not matched after timed set")
- }
- if s.Len() != 2 {
- t.Error("Not len 2 after set")
- }
-
- s.AddExpiring("bar", time.Nanosecond*1)
- if s.In("bar") {
- t.Error("Matched after expired timer")
- }
-}