]> git.aero2k.de Git - urlbot-v3.git/commitdiff
poll: dynamic amount of options
authorThorsten <mail@aero2k.de>
Sun, 8 Sep 2024 10:57:20 +0000 (12:57 +0200)
committerThorsten <mail@aero2k.de>
Sun, 8 Sep 2024 11:00:06 +0000 (13:00 +0200)
src/distbot/plugins/votepoll.py

index c4003f00557bd78b9e98b2f6ebfa56b9bc673377..0f421c4cfbf9329bea9ccfc84fa8b9d678373782 100644 (file)
@@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class Poll:
-    option_a: str
-    option_b: str
+    options: list[str]
     votes: dict[str, str]
     ending_at: float
 
@@ -29,27 +28,33 @@ class Poll:
             self.votes[user] = vote
         return self
 
-    @property
-    def options(self):
-        return {self.option_a, self.option_b}
+    def option(self, option: str):
+        """
+        resolve lettered or actual option to the actual option
+        """
+        if option in self.options:
+            return option
+        elif option.isalpha() and len(option) == 1:
+            index = ord(option.lower()) - ord('a')
+            if 0 <= index < len(self.options):
+                return self.options[index]
 
     @property
     def key(self):
-        return self.generate_key(self.option_a, self.option_b)
+        return self.generate_key(self.options)
 
     @property
     def end_date(self):
         return datetime.fromtimestamp(self.ending_at)
 
-    def votes_a(self) -> int:
-        return len([vote for vote in self.votes.values() if vote == self.option_a])
-
-    def votes_b(self) -> int:
-        return len([vote for vote in self.votes.values() if vote == self.option_b])
+    def get_votes(self, option: str) -> int:
+        option = self.option(option)
+        if option:
+            return len(self.votes.get(option, []))
 
     @staticmethod
-    def generate_key(option_a, option_b):
-        return hashlib.sha256("".join(sorted((option_a, option_b))).encode()).hexdigest()
+    def generate_key(options: list[str]) -> str:
+        return hashlib.sha256("".join(sorted(options)).encode()).hexdigest()
 
     def due(self, timestamp=None) -> bool:
         if not timestamp:
@@ -59,12 +64,20 @@ class Poll:
     @staticmethod
     def from_json(data: dict):
         try:
-            return Poll(
-                option_a=data['option_a'],
-                option_b=data['option_b'],
-                votes=data['votes'],
-                ending_at=float(data['ending_at'])
-            )
+            if "option_a" in data:
+                return Poll(
+                    options=[data['option_a'], data['option_b']],
+                    votes=data['votes'],
+                    ending_at=float(data['ending_at'])
+                )
+            elif "options" in data:
+                return Poll(
+                    options=data['options'],
+                    votes=data['votes'],
+                    ending_at=float(data['ending_at'])
+                )
+            else:
+                raise ValueError(f"Can't parse option data")
         except (KeyError, TypeError, ValueError) as e:
             raise ValueError(f"Invalid poll data: {e}")
 
@@ -74,15 +87,21 @@ class Poll:
     def status_report(self):
         # Calculate winner and format result message
         winner = "It's a tie!"
-        if self.votes_a() > self.votes_b():
-            winner = f"**Winner:** {self.option_a}"
-        elif self.votes_a() < self.votes_b():
-            winner = f"**Winner:** {self.option_b}"
-
-        status_message = (f"**Poll status:**\n"
-                          f"  * A - {self.option_a}: {self.votes_a()}\n"
-                          f"  * B - {self.option_b}: {self.votes_b()}"
-                          ) + (f"\n{winner}" if self.due() else f"\nEnding at: {self.end_date}")
+        if self.votes:
+            max_votes = max(self.votes.values())
+            winners = [option for option, votes in self.votes.items() if votes == max_votes]
+        else:
+            winners = []
+
+        if len(winners) == 1:
+            winner = f"**Winner:** {winners[0]}"
+        elif len(winners) > 1:
+            winner = f"**Winners:** {', '.join(winners)}"
+
+        status_message = f"**Poll status:**\n"
+        for option in self.options:
+            status_message += f"  * {option}: {self.get_votes(option)}\n"
+        status_message += (f"\n{winner}" if self.due() else f"\nEnding at: {self.end_date}")
         return status_message
 
 
@@ -94,7 +113,7 @@ class VotePoll(Worker):
         - "active" key with the hashed key of the active poll
     """
     binding_keys = [
-        "nick.poll.*.vs.*", "nick.poll.*.vs.*.*",
+        "nick.poll.*.vs.*", "nick.poll.*.vs.*.#",
         "nick.vote.*",
         "nick.pollstatus",
         "nick.endpoll",
@@ -127,29 +146,23 @@ class VotePoll(Worker):
 
         if words[0] == "vote":
             # check if words contain any option
-            map_ab = {
-                "A": active_poll.option_a,
-                "B": active_poll.option_b,
-            }
-            vote = map_ab.get(words[1].upper(), words[1])
-
-            if vote in active_poll.options:
+            vote = active_poll.option(words[1])
+            if vote:
                 self.persist(active_poll.add_vote(sender, vote))
                 return Action(msg="Vote added.")
             else:
                 return Action(msg=f"not a valid option: {vote} (valid: {active_poll.options})")
         elif words[0] == "poll":
             # parse options
-            option_a = words[1]
-            option_b = words[3]
-            vote_duration = self.vote_duration
-            if len(words) == 5:
-                try:
-                    vote_duration = max(30, min(self.max_vote_duration, int(words[4])))
-                except ValueError as e:
-                    logger.exception("Failed parsing intended duration", exc_info=e)
+            options, vote_duration = self.parse_words_for_poll_options(words[1:])
+            if len(options) > 10:
+                return Action(msg="Too many options (limit: 10).")
+            elif len(set(options)) < len(options):
+                # duplicate options
+                return Action(msg="No duplicate options.")
+            vote_duration = self.vote_duration if not vote_duration else vote_duration
             # setup new poll
-            poll = Poll(option_a, option_b, {}, now + vote_duration)
+            poll = Poll(options, {}, now + vote_duration)
             # check prior results
             if active_poll:
                 return Action(msg="There is already an active poll.")
@@ -163,7 +176,12 @@ class VotePoll(Worker):
             self.start_poll(poll)
             # setup timeout to disable poll and present results
             bot_nick = conf_get("bot_nickname")
-            poll_message = f"**New Vote:** {sender} started a vote! Vote for A: {option_a} or B: {option_b} (reply with '{bot_nick}: vote foo' within {vote_duration}s)"
+
+            option_listing = "\n".join(" * {} - {}".format(chr(65 + i), option) for i, option in enumerate(options))
+            poll_message = (f"**New Vote:** {sender} started a vote! "
+                            f"Vote within {vote_duration}s "
+                            f"(reply with '{bot_nick}: vote A/B/foo' within {vote_duration}s)"
+                            f"\nOptions are:\n") + option_listing
             return Action(
                 msg=poll_message,
                 event=Action(time=now + vote_duration, command="nick.pollstatus",
@@ -175,10 +193,30 @@ class VotePoll(Worker):
             return self.end_poll()
         elif words[0] == "droppoll":
             sudoers = conf_get('sudoers') or []
-            poll_key = Poll.generate_key(words[1], words[3])
+            options, _ = self.parse_words_for_poll_options(words[1:])
+            poll_key = Poll.generate_key(options)
             if sender in sudoers:
                 return self.drop_poll(poll_key)
 
+    @staticmethod
+    def parse_words_for_poll_options(words):
+        options = []
+        vote_duration = None
+        for i, word in enumerate(words):
+            if i % 2 == 1 or i == len(words) - 1:
+                if word == "vs":
+                    options.append(words[i - 1])
+                elif i == len(words) - 1 and words[i-1] == "vs":
+                    options.append(word)
+                elif i == len(words) - 1 and word.isnumeric():
+                    options.append(words[i - 1])
+                    vote_duration = int(word)
+                elif i == len(words) - 1:
+                    options.append(words[i])
+                else:
+                    logger.warning("something wrong parsing options: %s", words)
+        return options, vote_duration
+
     def get_active_poll(self) -> Poll:
         active_key = self.db.get(self.KEY_ACTIVE)
         if active_key: