chat: Fix race conditions.
authorAndrey Petrov <andrey.petrov@shazow.net>
Tue, 12 Jul 2016 15:09:57 +0000 (11:09 -0400)
committerAndrey Petrov <andrey.petrov@shazow.net>
Tue, 12 Jul 2016 15:18:07 +0000 (11:18 -0400)
chat/room.go
chat/room_test.go
host.go

index 2d7a9832b8bfe0d0f37818044f018ab79c48ad19..bf2128c07bd4e75ee7dc75f4b211e7d196283ee5 100644 (file)
@@ -23,18 +23,19 @@ var ErrInvalidName = errors.New("invalid name")
 // Member is a User with per-Room metadata attached to it.
 type Member struct {
        *message.User
-       Op bool
 }
 
 // Room definition, also a Set of User Items
 type Room struct {
        topic     string
        history   *message.History
-       members   *idSet
        broadcast chan message.Message
        commands  Commands
        closed    bool
        closeOnce sync.Once
+
+       Members *idSet
+       Ops     *idSet
 }
 
 // NewRoom creates a new room.
@@ -44,8 +45,10 @@ func NewRoom() *Room {
        return &Room{
                broadcast: broadcast,
                history:   message.NewHistory(historyLen),
-               members:   newIdSet(),
                commands:  *defaultCommands,
+
+               Members: newIdSet(),
+               Ops:     newIdSet(),
        }
 }
 
@@ -58,10 +61,10 @@ func (r *Room) SetCommands(commands Commands) {
 func (r *Room) Close() {
        r.closeOnce.Do(func() {
                r.closed = true
-               r.members.Each(func(m identified) {
+               r.Members.Each(func(m identified) {
                        m.(*Member).Close()
                })
-               r.members.Clear()
+               r.Members.Clear()
                close(r.broadcast)
        })
 }
@@ -92,7 +95,7 @@ func (r *Room) HandleMsg(m message.Message) {
                }
 
                r.history.Add(m)
-               r.members.Each(func(u identified) {
+               r.Members.Each(func(u identified) {
                        user := u.(*Member).User
                        if skip && skipUser == user {
                                // Skip
@@ -137,23 +140,24 @@ func (r *Room) Join(u *message.User) (*Member, error) {
        if u.Id() == "" {
                return nil, ErrInvalidName
        }
-       member := Member{u, false}
-       err := r.members.Add(&member)
+       member := Member{u}
+       err := r.Members.Add(&member)
        if err != nil {
                return nil, err
        }
        r.History(u)
-       s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len())
+       s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.Members.Len())
        r.Send(message.NewAnnounceMsg(s))
        return &member, nil
 }
 
 // Leave the room as a user, will announce. Mostly used during setup.
 func (r *Room) Leave(u message.Identifier) error {
-       err := r.members.Remove(u)
+       err := r.Members.Remove(u)
        if err != nil {
                return err
        }
+       r.Ops.Remove(u)
        s := fmt.Sprintf("%s left.", u.Name())
        r.Send(message.NewAnnounceMsg(s))
        return nil
@@ -164,7 +168,7 @@ func (r *Room) Rename(oldId string, identity message.Identifier) error {
        if identity.Id() == "" {
                return ErrInvalidName
        }
-       err := r.members.Replace(oldId, identity)
+       err := r.Members.Replace(oldId, identity)
        if err != nil {
                return err
        }
@@ -189,7 +193,7 @@ func (r *Room) Member(u *message.User) (*Member, bool) {
 }
 
 func (r *Room) MemberById(id string) (*Member, bool) {
-       m, err := r.members.Get(id)
+       m, err := r.Members.Get(id)
        if err != nil {
                return nil, false
        }
@@ -198,8 +202,7 @@ func (r *Room) MemberById(id string) (*Member, bool) {
 
 // IsOp returns whether a user is an operator in this room.
 func (r *Room) IsOp(u *message.User) bool {
-       m, ok := r.Member(u)
-       return ok && m.Op
+       return r.Ops.In(u)
 }
 
 // Topic of the room.
@@ -215,7 +218,7 @@ func (r *Room) SetTopic(s string) {
 // NamesPrefix lists all members' names with a given prefix, used to query
 // for autocompletion purposes.
 func (r *Room) NamesPrefix(prefix string) []string {
-       members := r.members.ListPrefix(prefix)
+       members := r.Members.ListPrefix(prefix)
        names := make([]string, len(members))
        for i, u := range members {
                names[i] = u.(*Member).User.Name()
index 05fbf02d62ef2c72d5fe3a4b12b62c52cc7659f7..613535886a06f095c5cfcbd5d1ad3897d52d2b52 100644 (file)
@@ -161,10 +161,15 @@ func TestQuietToggleDisplayState(t *testing.T) {
                t.Fatal(err)
        }
 
-       // Drain the initial Join message
-       <-ch.broadcast
+       u.HandleMsg(<-u.ConsumeChan(), s)
+       expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
+       s.Read(&actual)
+       if !reflect.DeepEqual(actual, expected) {
+               t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+       }
 
        ch.Send(message.ParseInput("/quiet", u))
+
        u.HandleMsg(<-u.ConsumeChan(), s)
        expected = []byte("-> Quiet mode is toggled ON" + message.Newline)
        s.Read(&actual)
@@ -173,9 +178,9 @@ func TestQuietToggleDisplayState(t *testing.T) {
        }
 
        ch.Send(message.ParseInput("/quiet", u))
+
        u.HandleMsg(<-u.ConsumeChan(), s)
        expected = []byte("-> Quiet mode is toggled OFF" + message.Newline)
-
        s.Read(&actual)
        if !reflect.DeepEqual(actual, expected) {
                t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
@@ -197,10 +202,15 @@ func TestRoomNames(t *testing.T) {
                t.Fatal(err)
        }
 
-       // Drain the initial Join message
-       <-ch.broadcast
+       u.HandleMsg(<-u.ConsumeChan(), s)
+       expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
+       s.Read(&actual)
+       if !reflect.DeepEqual(actual, expected) {
+               t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+       }
 
        ch.Send(message.ParseInput("/names", u))
+
        u.HandleMsg(<-u.ConsumeChan(), s)
        expected = []byte("-> 1 connected: foo" + message.Newline)
        s.Read(&actual)
diff --git a/host.go b/host.go
index 20633007f0c4ea5e24c61b59bc454713fb5792d6..37c933ea5b111b975ad9528503339b8b1548b840 100644 (file)
--- a/host.go
+++ b/host.go
@@ -114,7 +114,9 @@ func (h *Host) Connect(term *sshd.Terminal) {
        h.count++
 
        // Should the user be op'd on join?
-       member.Op = h.isOp(term.Conn)
+       if h.isOp(term.Conn) {
+               h.Room.Ops.Add(member)
+       }
        ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
 
        for {
@@ -458,7 +460,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
                        if !ok {
                                return errors.New("user not found")
                        }
-                       member.Op = true
+                       room.Ops.Add(member)
                        id := member.Identifier.(*Identity)
                        h.auth.Op(id.PublicKey(), until)