|
class Context(object):
    """Simple attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        """Store every ``name=value`` pair from *kwargs* on the instance."""
        vars(self).update(kwargs)
|
|
|
|
|
class Fragment(object):
    """Wrapper around a piece of query text.

    ``GrapheneTestCase`` in this file inserts ``Fragment`` values into the
    queries it builds as-is (via ``str``) instead of quoting or expanding
    them.
    """

    def __str__(self):
        """Return the wrapped text unchanged."""
        return self.fragment

    def __init__(self, fragment):
        """Remember *fragment*, the text this object stands for."""
        self.fragment = fragment
|
|
|
|
|
class GrapheneTestCase(object):
    """Helpers for building GraphQL queries and asserting on their results.

    The assertion helpers call ``self.assertEqual``, which this class does
    not define, so it is presumably meant to be mixed into a
    ``unittest.TestCase`` subclass.
    """

    def _cast_dict(self, d):
        """Recursively cast an ordered dict to an unordered dict.

        This makes unittest errors more informative.

        Dicts nested as values, or inside (possibly nested) lists, are cast
        as well. The input is not modified; ``None`` is returned unchanged.
        """
        if d is None:
            return d
        return {k: self._cast_value(v) for k, v in dict(d).items()}

    def _cast_value(self, value):
        """Cast any dicts found inside *value*; other values pass through."""
        if isinstance(value, dict):
            return self._cast_dict(value)
        if isinstance(value, list):
            # Build a real list: ``map`` is a lazy iterator on Python 3 and
            # would never compare equal to the expected list.
            return [self._cast_value(item) for item in value]
        return value

    def _encode_arg(self, arg):
        """Encode a Python value for use as a GraphQL argument.

        * ``bool`` -> ``true`` / ``false``
        * ``list`` -> ``[a, b, ...]`` with every element encoded recursively
        * ``Fragment`` -> returned unchanged (rendered verbatim by ``%s``)
        * string -> wrapped in double quotes (no escaping is performed)
        * anything else -> returned unchanged
        """
        if isinstance(arg, bool):
            # GraphQL booleans are lowercase; ``True``/``False`` are invalid.
            return 'true' if arg else 'false'
        if isinstance(arg, list):
            # Elements may encode to non-strings (numbers, Fragments), so
            # format each one before joining.
            return '[%s]' % ', '.join('%s' % (self._encode_arg(a),) for a in arg)
        if isinstance(arg, Fragment):
            return arg
        if isinstance(arg, six.string_types):
            return '"%s"' % arg
        return arg

    def _build_subquery(self, value, arguments):
        """Return the sub-selection for one field value, or '' for a leaf."""
        if isinstance(value, Fragment):
            return value
        if isinstance(value, dict):
            return self.build_query(value, arguments)
        if isinstance(value, list):
            # A list field is described by its first element. An empty list
            # or a list of scalars has no sub-selection.
            return self._build_subquery(value[0], arguments) if value else ''
        return ''

    def build_query(self, d, arguments=None):
        """Build a GraphQL query fragment from the shape of the expected data.

        :param d: mapping of field name to expected value. Dict values become
            nested selections, lists are described by their first element,
            ``Fragment`` values are inserted verbatim and anything else is a
            leaf field.
        :param arguments: optional mapping of field name to a dict of
            argument names and values. The same mapping is used at every
            nesting level, so an entry applies wherever that field name
            occurs.
        :returns: the selection as a string, e.g. ``'{ a b { c } }'``.
        """
        if arguments is None:
            arguments = {}
        query = []
        for k, v in d.items():
            field = k
            if arguments.get(k):
                local_args = [
                    '%s:%s' % (arg_k, self._encode_arg(arg_v))
                    for arg_k, arg_v in arguments[k].items()
                ]
                field = '%s(%s)' % (k, ', '.join(local_args))

            subquery = self._build_subquery(v, arguments)
            if subquery:
                query.append('%s %s' % (field, subquery))
            else:
                query.append(field)
        return '{ %s }' % ' '.join(query)

    def connection(self, ids):
        """Return the ``edges``/``node`` structure holding one node per id."""
        return {
            'edges': [{'node': {'id': node_id}} for node_id in ids]
        }

    def assertQuerySuccess(self, query, expected, context=None):
        """Execute *query* and assert that its data equals *expected*.

        The query runs against the module-level ``schema``, which is not
        defined in this class and must be provided elsewhere in the module.
        *context* is passed to ``schema.execute`` as ``context_value``. Any
        errors on the result are included in the failure message.
        """
        result = schema.execute(query, context_value=context)
        return_data = self._cast_dict(result.data)
        if result.errors:
            msg = 'Query failed: %s' % result.errors
        else:
            msg = None
        self.assertEqual(return_data, expected, msg)

    def assertQueryComplete(self, expected, node, schema):
        """Assert that the keys of *expected* match the fields of *node*.

        The available fields are read from
        ``node.internal_type(schema).get_fields()``.
        """
        available_node_keys = set(node.internal_type(schema).get_fields().keys())
        queried_keys = set(expected.keys())
        # Sorted so the failure message is deterministic.
        missing = sorted(available_node_keys - queried_keys)
        unexpected = sorted(queried_keys - available_node_keys)
        msg = 'Test query did not cover all available fields.\nMissing: %s' % (
            ', '.join(missing)
        )
        if unexpected:
            # Without this the message shows an empty "Missing" list when the
            # mismatch comes from keys the node does not have.
            msg += '\nUnexpected: %s' % ', '.join(unexpected)
        self.assertEqual(available_node_keys, queried_keys, msg)