Browse Source

Improve token managment

- Add remove token on logout
- Cleanup tokens on add/remove
Alexey Edelev 5 years ago
parent
commit
907abe3cdf
2 changed files with 60 additions and 9 deletions
  1. 59 9
      db/db.go
  2. 1 0
      web/server.go

+ 59 - 9
db/db.go

@@ -198,23 +198,23 @@ func (s *Storage) CheckUser(user, password string) error {
 }
 
 func (s *Storage) AddToken(user, token string) error {
-	log.Printf("add token: %s, %s", user, token)
+	log.Printf("Add token: %s\n", user)
 	s.tokensCollection.UpdateOne(context.Background(),
 		bson.M{"user": user},
 		bson.M{
 			"$addToSet": bson.M{
 				"token": bson.M{
 					"token":  token,
-					"expire": time.Now().Add(time.Hour * 96).Unix(),
+					"expire": time.Now().Add(time.Hour * 24).Unix(),
 				},
 			},
 		},
 		options.Update().SetUpsert(true))
+	s.CleanupTokens(user)
 	return nil
 }
 
 func (s *Storage) CheckToken(user, token string) error {
-	log.Printf("Check token: %s %s", user, token)
 	if token == "" {
 		return errors.New("Invalid token")
 	}
@@ -224,7 +224,6 @@ func (s *Storage) CheckToken(user, token string) error {
 			bson.M{"$match": bson.M{"user": user}},
 			bson.M{"$unwind": "$token"},
 			bson.M{"$match": bson.M{"token.token": token}},
-			bson.M{"$project": bson.M{"_id": 0, "token.expire": 1}},
 		})
 
 	if err != nil {
@@ -232,6 +231,7 @@ func (s *Storage) CheckToken(user, token string) error {
 		return err
 	}
 
+	ok := false
 	defer cur.Close(context.Background())
 	if cur.Next(context.Background()) {
 		result := struct {
@@ -242,15 +242,66 @@ func (s *Storage) CheckToken(user, token string) error {
 
 		err = cur.Decode(&result)
 
-		if err == nil && result.Token.Expire >= time.Now().Unix() {
-			log.Printf("Check token %s expire: %d", user, result.Token.Expire)
-			return nil
-		}
+		ok = err == nil && result.Token.Expire >= time.Now().Unix()
+	}
+
+	if ok {
+		//TODO: Renew token
+		return nil
 	}
 
 	return errors.New("Token expired")
 }
 
+func (s *Storage) RemoveToken(user, token string) error {
+	s.CleanupTokens(user)
+
+	_, err := s.tokensCollection.UpdateOne(context.Background(), bson.M{"user": user}, bson.M{"$pull": bson.M{"token": bson.M{"token": token}}})
+	if err != nil {
+		log.Printf("Unable to remove token %s", err)
+	}
+
+	return err
+}
+
+func (s *Storage) CleanupTokens(user string) {
+	log.Printf("Cleanup tokens: %s\n", user)
+
+	cur, err := s.tokensCollection.Aggregate(context.Background(),
+		bson.A{
+			bson.M{"$match": bson.M{"user": user}},
+			bson.M{"$unwind": "$token"},
+		})
+
+	if err != nil {
+		log.Fatalln(err)
+	}
+
+	type tokenMetadata struct {
+		Expire int64
+		Token  string
+	}
+
+	tokensToKeep := bson.A{}
+	defer cur.Close(context.Background())
+	for cur.Next(context.Background()) {
+		result := struct {
+			Token *tokenMetadata
+		}{
+			Token: &tokenMetadata{},
+		}
+
+		err = cur.Decode(&result)
+		if err == nil && result.Token.Expire >= time.Now().Unix() {
+			tokensToKeep = append(tokensToKeep, result.Token)
+		} else {
+			log.Printf("Expired token found for %s : %d", user, result.Token.Expire)
+		}
+	}
+
+	_, err = s.tokensCollection.UpdateOne(context.Background(), bson.M{"user": user}, bson.M{"$set": bson.M{"token": tokensToKeep}})
+	return
+}
 func (s *Storage) SaveMail(email, folder string, m *common.Mail) error {
 	result := &struct {
 		User string
@@ -408,7 +459,6 @@ func (s *Storage) GetUsers() (users []string, err error) {
 }
 
 func (s *Storage) GetEmails(user string) (emails []string, err error) {
-	fmt.Printf("user: %s\n", user)
 	result := &struct {
 		Email []string
 	}{}

+ 1 - 0
web/server.go

@@ -174,6 +174,7 @@ func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
 	fmt.Println("logout")
 
 	session, _ := s.sessionStore.Get(r, CookieSessionToken)
+	s.storage.RemoveToken(session.Values["user"].(string), session.Values["token"].(string))
 	session.Values["user"] = ""
 	session.Values["token"] = ""
 	session.Save(r, w)