e96945072969a70757dafa0207789df268a4c78d
[django-ssify.git] / ssify / variables.py
1 # -*- coding: utf-8 -*-
2 # This file is part of django-ssify, licensed under GNU Affero GPLv3 or later.
3 # Copyright © Fundacja Nowoczesna Polska. See README.md for more information.
4 #
5 """
6 Utilities for defining SSI variables.
7
8 SSI variables are a way of providing values that need to be computed
9 at request time to the prerendered templates.
10
11 """
12 from __future__ import unicode_literals
13 from hashlib import md5
14 from django import template
15 from django.utils.encoding import force_text, python_2_unicode_compatible
16 from django.utils.functional import Promise
17 from django.utils.safestring import mark_safe
18 from .exceptions import SsiVarsDependencyCycleError
19
20
21 @python_2_unicode_compatible
22 class SsiVariable(object):
23     """
24     Represents a variable computed by a template tag with given arguments.
25
26     Instance of this class is returned from any template tag created
27     with `decorators.ssi_variable` decorator. If renders as SSI echo
28     statement, but you can also use it as an argument to {% ssi_include %},
29     to other ssi_variable, or create SSI if statements by using
30     its `if`, `else`, `endif` properties.
31
32     Variable's name, as used in SSI statements, is a hash of its definition,
33     so the user never has to deal with it directly.
34
35     """
36     def __init__(self, tagpath=None, args=None, kwargs=None, name=None):
37         self.tagpath = tagpath
38         self.args = list(args or [])
39         self.kwargs = kwargs or {}
40         self._name = name
41
42     @property
43     def name(self):
44         """Variable name is a hash of its definition."""
45         if self._name is None:
46             self._name = 'v' + md5(json_encode(self.definition).encode('ascii')).hexdigest()
47         return self._name
48
49     def rehash(self):
50         """
51         Sometimes there's a need to reset the variable name.
52
53         Typically, this is the case after finding real values for
54         variables passed as arguments to {% ssi_include %}.
55         """
56         self._name = None
57         return self.name
58
59     @property
60     def definition(self):
61         """Variable is defined by path to template tag and its arguments."""
62         if self.kwargs:
63             return self.tagpath, self.args, self.kwargs
64         elif self.args:
65             return self.tagpath, self.args
66         else:
67             return self.tagpath,
68
69     def __repr__(self):
70         return "SsiVariable(%s: %s)" % (self.name, repr(self.definition))
71
72     def get_value(self, request):
73         """Computes the real value of the variable, using the request."""
74         taglib, tagname = self.tagpath.rsplit('.', 1)
75         return template.get_library(taglib).tags[tagname].get_value(
76             request, *self.args, **self.kwargs)
77
78     def __str__(self):
79         return mark_safe("<!--#echo var='%s' encoding='none'-->" % self.name)
80
81     def as_var(self):
82         """Returns the form that can be used in SSI include's URL."""
83         return '${%s}' % self.name
84
85 # If-else-endif properties for use in templates.
86 setattr(SsiVariable, 'if',
87         lambda self: mark_safe("<!--#if expr='${%s}'-->" % self.name))
88 setattr(SsiVariable, 'else',
89         staticmethod(lambda: mark_safe("<!--#else-->")))
90 setattr(SsiVariable, 'endif',
91         staticmethod(lambda: mark_safe('<!--#endif-->')))
92
93
94 class SsiExpect(object):
95     """This class says: I want the real value of this variable here."""
96     def __init__(self, name):
97         self.name = name
98     def __repr__(self):
99         return "SsiExpect(%s)" % (self.name,)
100
101
102 def ssi_expect(var, type_):
103     """
104     Helper function for defining get_ssi_vars on ssi_included views.
105
106     The view needs a way of calculating all the needed variables from
107     the view args. But the args are probably the wrong type
108     (typically, str instead of int) or even are SsiVariables, not
109     resolved until request time.
110
111     This function provides a way to expect a real value of the needed type.
112
113     """
114     if isinstance(var, SsiVariable):
115         return SsiExpect(var.name)
116     else:
117         return type_(var)
118
119
120 class SsiVariableNode(template.Node):
121     """ Node for the SsiVariable tags. """
122     def __init__(self, tagpath, args, kwargs, vary=None, asvar=None):
123         self.tagpath = tagpath
124         self.args = args
125         self.kwargs = kwargs
126         self.vary = vary
127         self.asvar = asvar
128
129     def __repr__(self):
130         return "<SsiVariableNode>"
131
132     def render(self, context):
133         """Renders the tag as SSI echo or sets the context variable."""
134         resolved_args = [var.resolve(context) for var in self.args]
135         resolved_kwargs = dict((k, v.resolve(context))
136                                for k, v in self.kwargs.items())
137         var = SsiVariable(self.tagpath, resolved_args, resolved_kwargs)
138
139         request = context['request']
140         request.ssi_vars_needed[var.name] = var
141         if self.vary:
142             request.ssi_vary.update(self.vary)
143
144         if self.asvar:
145             context.dicts[0][self.asvar] = var
146             return ''
147         else:
148             return var
149
150
151 def ssi_set_statement(var, value):
152     """Generates an SSI set statement for a variable."""
153     if isinstance(value, Promise):
154         # Yes, this is quite brutal. But we need to know
155         # the real value now, we don't know the type,
156         # and we only want to evaluate the lazy function once.
157         value = value._proxy____cast()
158     if value is False or value is None:
159         value = ''
160     return "<!--#set var='%s' value='%s'-->" % (
161         var,
162         force_text(value).replace('\\', '\\\\').replace("'", "\\'"))
163
164
165 def provide_vars(request, ssi_vars):
166     """
167     Provides all the SSI set statements for ssi_vars variables.
168
169     The main purpose of this function is to by called by SsifyMiddleware.
170     """
171     def resolve_expects(var):
172         if not hasattr(var, 'hash_dirty'):
173             var.hash_dirty = False
174
175         for i, arg in enumerate(var.args):
176             if isinstance(arg, SsiExpect):
177                 var.args[i] = resolved[arg.name]
178                 var.hash_dirty = True
179         for k, arg in var.kwargs.items():
180             if isinstance(arg, SsiExpect):
181                 var.kwargs[k] = resolved[arg.name]
182                 var.hash_dirty = True
183
184         for arg in var.args + list(var.kwargs.values()):
185             if isinstance(arg, SsiVariable):
186                 var.hash_dirty = resolve_expects(arg) or var.hash_dirty
187
188         hash_dirty = var.hash_dirty
189         if var.hash_dirty:
190             # Rehash after calculating the SsiExpects with real
191             # values, because that's what the included views expect.
192             var.rehash()
193             var.hash_dirty = False
194
195         return hash_dirty
196
197     def resolve_args(var):
198         kwargs = {}
199         for k, arg in var.kwargs.items():
200             kwargs[k] = resolved[arg.name] if isinstance(arg, SsiVariable) else arg
201         new_var = SsiVariable(var.tagpath,
202             [resolved[arg.name] if isinstance(arg, SsiVariable) else arg for arg in var.args],
203             kwargs)
204         return new_var
205
206     resolved = {}
207     queue = list(ssi_vars.values())
208     
209     unresolved_streak = 0
210     while queue:
211         var = queue.pop(0)
212         try:
213             resolve_expects(var)
214             rv = resolve_args(var)
215         except KeyError as e:
216             queue.append(var)
217             unresolved_streak += 1
218             if unresolved_streak > len(queue):
219                 raise SsiVarsDependencyCycleError(request, queue, resolved)
220             continue
221
222         resolved[var.name] = rv.get_value(request)
223         unresolved_streak = 0
224
225     output = "".join(ssi_set_statement(var, value)
226                       for (var, value) in resolved.items()
227                       ).encode('utf-8')
228     return output
229
230
231 from .serializers import json_encode