d4701ab7e23c5de319a6fb995d0de4117a231e85
[wolnelektury.git] / src / wolnelektury / templatetags / switch_tag.py
1 # Source: http://djangosnippets.org/snippets/967/
2 # Author: adurdin
3 # Posted: August 13, 2008
4 #
5 #
6 # We can use it based on djangosnippets Terms of Service:
7 # (http://djangosnippets.org/about/tos/)
8 #
9 # 2. That you grant any third party who sees the code you post
10 # a royalty-free, non-exclusive license to copy and distribute that code
11 # and to make and distribute derivative works based on that code. You may
12 # include license terms in snippets you post, if you wish to use
13 # a particular license (such as the BSD license or GNU GPL), but that
14 # license must permit royalty-free copying, distribution and modification
15 # of the code to which it is applied.
16
17 from django import template
18 from django.template import Library, Node, VariableDoesNotExist
19
20 register = Library()
21
22
23 @register.tag(name="switch")
24 def do_switch(parser, token):
25     """
26     The ``{% switch %}`` tag compares a variable against one or more values in
27     ``{% case %}`` tags, and outputs the contents of the matching block.  An
28     optional ``{% else %}`` tag sets off the default output if no matches
29     could be found::
30
31         {% switch result_count %}
32             {% case 0 %}
33                 There are no search results.
34             {% case 1 %}
35                 There is one search result.
36             {% else %}
37                 Jackpot! Your search found {{ result_count }} results.
38         {% endswitch %}
39
40     Each ``{% case %}`` tag can take multiple values to compare the variable
41     against::
42
43         {% switch username %}
44             {% case "Jim" "Bob" "Joe" %}
45                 Me old mate {{ username }}! How ya doin?
46             {% else %}
47                 Hello {{ username }}
48         {% endswitch %}
49     """
50     bits = token.contents.split()
51     tag_name = bits[0]
52     if len(bits) != 2:
53         raise template.TemplateSyntaxError("'%s' tag requires one argument" % tag_name)
54     variable = parser.compile_filter(bits[1])
55
56     class BlockTagList(object):
57         # This is a bit of a hack, as it embeds knowledge of the behaviour
58         # of Parser.parse() relating to the "parse_until" argument.
59         def __init__(self, *names):
60             self.names = set(names)
61
62         def __contains__(self, token_contents):
63             name = token_contents.split()[0]
64             return name in self.names
65
66     # Skip over everything before the first {% case %} tag
67     parser.parse(BlockTagList('case', 'endswitch'))
68
69     cases = []
70     token = parser.next_token()
71     got_case = False
72     got_else = False
73     while token.contents != 'endswitch':
74         nodelist = parser.parse(BlockTagList('case', 'else', 'endswitch'))
75
76         if got_else:
77             raise template.TemplateSyntaxError("'else' must be last tag in '%s'." % tag_name)
78
79         contents = token.contents.split()
80         token_name, token_args = contents[0], contents[1:]
81
82         if token_name == 'case':
83             tests = map(parser.compile_filter, token_args)
84             case = (tests, nodelist)
85             got_case = True
86         else:
87             # The {% else %} tag
88             case = (None, nodelist)
89             got_else = True
90         cases.append(case)
91         token = parser.next_token()
92
93     if not got_case:
94         raise template.TemplateSyntaxError("'%s' must have at least one 'case'." % tag_name)
95
96     return SwitchNode(variable, cases)
97
98
99 class SwitchNode(Node):
100     def __init__(self, variable, cases):
101         self.variable = variable
102         self.cases = cases
103
104     def __repr__(self):
105         return "<Switch node>"
106
107     def __iter__(self):
108         for tests, nodelist in self.cases:
109             for node in nodelist:
110                 yield node
111
112     def get_nodes_by_type(self, nodetype):
113         nodes = []
114         if isinstance(self, nodetype):
115             nodes.append(self)
116         for tests, nodelist in self.cases:
117             nodes.extend(nodelist.get_nodes_by_type(nodetype))
118         return nodes
119
120     def render(self, context):
121         try:
122             value_missing = False
123             value = self.variable.resolve(context, True)
124         except VariableDoesNotExist:
125             no_value = True
126             value_missing = None
127             value = None
128
129         for tests, nodelist in self.cases:
130             if tests is None:
131                 return nodelist.render(context)
132             elif not value_missing:
133                 for test in tests:
134                     test_value = test.resolve(context, True)
135                     if value == test_value:
136                         return nodelist.render(context)
137         else:
138             return ""