diff --git a/fbchat/graphql.py b/fbchat/graphql.py index f69089d..75c386a 100644 --- a/fbchat/graphql.py +++ b/fbchat/graphql.py @@ -133,7 +133,7 @@ def graphql_to_poll(a): title=a.get('title') if a.get('title') else a.get("text"), options=[graphql_to_poll_option(m) for m in a.get('options')] ) - rtn.uid = a.get("id") + rtn.uid = int(a["id"]) rtn.options_count = a.get("total_count") return rtn @@ -142,7 +142,7 @@ def graphql_to_poll_option(a): text=a.get('text'), vote=a.get('viewer_has_voted') == 'true' if isinstance(a.get('viewer_has_voted'), str) else a.get('viewer_has_voted') ) - rtn.uid = a.get('id') + rtn.uid = int(a["id"]) rtn.voters = [m.get('node').get('id') for m in a.get('voters').get('edges')] if isinstance(a.get('voters'), dict) else a.get('voters') rtn.votes_count = a.get('voters').get('count') if isinstance(a.get('voters'), dict) else a.get('total_count') return rtn diff --git a/tests/test_polls.py b/tests/test_polls.py new file mode 100644 index 0000000..ef53917 --- /dev/null +++ b/tests/test_polls.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals + +import pytest + +from fbchat.models import Poll, PollOption, ThreadType +from utils import random_hex, subset + + +@pytest.fixture(scope="module", params=[ + Poll(title=random_hex(), options=[]), + Poll(title=random_hex(), options=[ + PollOption(random_hex(), vote=True), + PollOption(random_hex(), vote=True), + ]), + Poll(title=random_hex(), options=[ + PollOption(random_hex(), vote=False), + PollOption(random_hex(), vote=False), + ]), + Poll(title=random_hex(), options=[ + PollOption(random_hex(), vote=True), + PollOption(random_hex(), vote=True), + PollOption(random_hex(), vote=False), + PollOption(random_hex(), vote=False), + PollOption(random_hex()), + PollOption(random_hex()), + ]), + pytest.mark.xfail(Poll(title=None, options=[]), raises=ValueError), +]) +def poll_data(request, client1, group, catch_event): + with catch_event("onPollCreated") as x: + client1.createPoll(request.param, thread_id=group["id"]) + options = client1.fetchPollOptions(x.res["poll"].uid) + return x.res, request.param, options + + +def test_create_poll(client1, group, catch_event, poll_data): + event, poll, _ = poll_data + assert subset( + event, + author_id=client1.uid, + thread_id=group["id"], + thread_type=ThreadType.GROUP, + ) + assert subset(vars(event["poll"]), title=poll.title, options_count=len(poll.options)) + for recv_option in event["poll"].options: # The recieved options may not be the full list + old_option = list(filter(lambda o: o.text == recv_option.text, poll.options))[0] + voters = [client1.uid] if old_option.vote else [] + assert subset(vars(recv_option), voters=voters, votes_count=len(voters), vote=False) + + +def test_fetch_poll_options(client1, group, catch_event, poll_data): + _, poll, options = poll_data + assert len(options) == len(poll.options) + for option in options: + assert subset(vars(option)) + + +def test_update_poll_vote(client1, group, catch_event, poll_data): + event, poll, options = poll_data + new_vote_ids = [o.uid for o in options[0:len(options):2] if not o.vote] + re_vote_ids = [o.uid for o in options[0:len(options):2] if o.vote] + new_options = [random_hex(), random_hex()] + with catch_event("onPollVoted") as x: + client1.updatePollVote(event["poll"].uid, option_ids=new_vote_ids + re_vote_ids, new_options=new_options) + + assert subset( + x.res, + author_id=client1.uid, + thread_id=group["id"], + thread_type=ThreadType.GROUP, + ) + assert subset(vars(x.res["poll"]), title=poll.title, options_count=len(options + new_options)) + for o in new_vote_ids: + assert o in x.res["added_options"] + assert len(x.res["added_options"]) == len(new_vote_ids) + len(new_options) + assert set(x.res["removed_options"]) == set(o.uid for o in options if o.vote and o.uid not in re_vote_ids)