Django 1.11 compatibility fixes.
[django-ssify.git] / ssify / variables.py
index 9ccc3f3..5e220ad 100644 (file)
@@ -1,3 +1,7 @@
+# -*- coding: utf-8 -*-
+# This file is part of django-ssify, licensed under GNU Affero GPLv3 or later.
+# Copyright © Fundacja Nowoczesna Polska. See README.md for more information.
+#
 """
 Utilities for defining SSI variables.
 
@@ -5,14 +9,31 @@ SSI variables are a way of providing values that need to be computed
 at request time to the prerendered templates.
 
 """
+from __future__ import unicode_literals
 from hashlib import md5
-from django import template
-from django.utils.encoding import force_unicode
+from django.template import Node
+
+try:
+    # Django < 1.9 still exposes get_library() in django.template.base.
+    from django.template.base import get_library
+except ImportError:
+    from importlib import import_module
+    from django.template.backends.django import get_installed_libraries
+
+    def get_library(taglib):  # stand-in for the get_library() import above
+        if not hasattr(get_library, 'libraries'):  # build the map only once
+            get_library.libraries = get_installed_libraries()
+        if isinstance(get_library.libraries[taglib], str):  # not imported yet
+            get_library.libraries[taglib] = import_module(get_library.libraries[taglib]).register
+        return get_library.libraries[taglib]
+
+from django.utils.encoding import force_text, python_2_unicode_compatible
 from django.utils.functional import Promise
 from django.utils.safestring import mark_safe
-from .exceptions import SsiVarsDependencyCycleError, UndeclaredSsiRefError
+from .exceptions import SsiVarsDependencyCycleError
 
 
+@python_2_unicode_compatible
 class SsiVariable(object):
     """
     Represents a variable computed by a template tag with given arguments.
@@ -27,8 +48,6 @@ class SsiVariable(object):
     so the user never has to deal with it directly.
 
     """
-    ret_types = 'bool', 'int', 'unicode'
-
     def __init__(self, tagpath=None, args=None, kwargs=None, name=None):
         self.tagpath = tagpath
         self.args = list(args or [])
@@ -39,7 +58,7 @@ class SsiVariable(object):
     def name(self):
         """Variable name is a hash of its definition."""
         if self._name is None:
-            self._name = 'v' + md5(json_encode(self.definition)).hexdigest()
+            self._name = 'v' + md5(json_encode(self.definition).encode('utf-8')).hexdigest()
         return self._name
 
     def rehash(self):
@@ -68,10 +87,10 @@ class SsiVariable(object):
     def get_value(self, request):
         """Computes the real value of the variable, using the request."""
         taglib, tagname = self.tagpath.rsplit('.', 1)
-        return template.get_library(taglib).tags[tagname].get_value(
+        return get_library(taglib).tags[tagname].get_value(
             request, *self.args, **self.kwargs)
 
-    def __unicode__(self):
+    def __str__(self):
         return mark_safe("<!--#echo var='%s' encoding='none'-->" % self.name)
 
     def as_var(self):
@@ -91,6 +110,8 @@ class SsiVariable(object):
     """This class says: I want the real value of this variable here."""
     def __init__(self, name):
         self.name = name
+    def __repr__(self):
+        return "SsiExpect(%s)" % (self.name,)
 
 
 def ssi_expect(var, type_):
@@ -111,13 +132,13 @@ def ssi_expect(var, type_):
         return type_(var)
 
 
-class SsiVariableNode(template.Node):
+class SsiVariableNode(Node):
     """ Node for the SsiVariable tags. """
-    def __init__(self, tagpath, args, kwargs, vary=None, asvar=None):
+    def __init__(self, tagpath, args, kwargs, patch_response=None, asvar=None):
         self.tagpath = tagpath
         self.args = args
         self.kwargs = kwargs
-        self.vary = vary
+        self.patch_response = patch_response
         self.asvar = asvar
 
     def __repr__(self):
@@ -131,9 +152,13 @@ class SsiVariableNode(template.Node):
         var = SsiVariable(self.tagpath, resolved_args, resolved_kwargs)
 
         request = context['request']
+        if not hasattr(request, 'ssi_vars_needed'):  # lazily create the registry
+            request.ssi_vars_needed = {}
         request.ssi_vars_needed[var.name] = var
-        if self.vary:
-            request.ssi_vary.update(self.vary)
+        if self.patch_response:
+            if not hasattr(request, 'ssi_patch_response'):
+                request.ssi_patch_response = []
+            request.ssi_patch_response.extend(self.patch_response)
 
         if self.asvar:
             context.dicts[0][self.asvar] = var
@@ -153,7 +178,7 @@ def ssi_set_statement(var, value):
         value = ''
     return "<!--#set var='%s' value='%s'-->" % (
         var,
-        force_unicode(value).replace(u'\\', u'\\\\').replace(u"'", u"\\'"))
+        force_text(value).replace('\\', '\\\\').replace("'", "\\'"))
 
 
 def provide_vars(request, ssi_vars):
@@ -162,50 +187,61 @@ def provide_vars(request, ssi_vars):
 
-    The main purpose of this function is to by called by SsifyMiddleware.
+    The main purpose of this function is to be called by SsifyMiddleware.
     """
+    def resolve_expects(var):
+        if not hasattr(var, 'hash_dirty'):  # flag survives KeyError retries
+            var.hash_dirty = False
+
+        for i, arg in enumerate(var.args):
+            if isinstance(arg, SsiExpect):
+                var.args[i] = resolved[arg.name]
+                var.hash_dirty = True
+        for k, arg in var.kwargs.items():
+            if isinstance(arg, SsiExpect):
+                var.kwargs[k] = resolved[arg.name]
+                var.hash_dirty = True
+
+        for arg in var.args + list(var.kwargs.values()):  # nested variables
+            if isinstance(arg, SsiVariable):
+                var.hash_dirty = resolve_expects(arg) or var.hash_dirty
+
+        hash_dirty = var.hash_dirty
+        if var.hash_dirty:
+            # Rehash after calculating the SsiExpects with real
+            # values, because that's what the included views expect.
+            var.rehash()
+            var.hash_dirty = False
+
+        return hash_dirty
+
+    def resolve_args(var):  # copy of var with SsiVariable args -> values
+        kwargs = {}
+        for k, arg in var.kwargs.items():
+            kwargs[k] = resolved[arg.name] if isinstance(arg, SsiVariable) else arg
+        new_var = SsiVariable(var.tagpath,
+            [resolved[arg.name] if isinstance(arg, SsiVariable) else arg for arg in var.args],
+            kwargs)
+        return new_var
+
     resolved = {}
-    queue = ssi_vars.items()
+    queue = list(ssi_vars.values())
+
     unresolved_streak = 0
     while queue:
-        var_name, var = queue.pop(0)
-        hash_dirty = False
-        new_name = var_name
-
+        var = queue.pop(0)
         try:
-            for i, arg in enumerate(var.args):
-                if isinstance(arg, SsiExpect):
-                    var.args[i] = resolved[arg.name]
-                    hash_dirty = True
-            for k, arg in var.kwargs.items():
-                if isinstance(arg, SsiExpect):
-                    var.args[k] = resolved[arg.name]
-                    hash_dirty = True
-
-            if hash_dirty:
-                # Rehash after calculating the SsiExpects with real
-                # values, because that's what the included views expect.
-                new_name = var.rehash()
-
-            for i, arg in enumerate(var.args):
-                if isinstance(arg, SsiVariable):
-                    var.args[i] = resolved[arg.name]
-            for k, arg in var.kwargs.items():
-                if isinstance(arg, SsiVariable):
-                    var.args[k] = resolved[arg.name]
-
-        except KeyError:
-            queue.append((var_name, var))
+            resolve_expects(var)
+            rv = resolve_args(var)
+        except KeyError:
+            queue.append(var)
             unresolved_streak += 1
-            if unresolved_streak == len(queue):
-                if arg.name in ssi_vars:
-                    raise SsiVarsDependencyCycleError(queue)
-                else:
-                    raise UndeclaredSsiRefError(request, var, arg.name)
+            if unresolved_streak > len(queue):  # no progress in a full pass
+                raise SsiVarsDependencyCycleError(request, queue, resolved)
             continue
 
-        resolved[new_name] = var.get_value(request)
+        resolved[var.name] = rv.get_value(request)
         unresolved_streak = 0
 
-    output = u"".join(ssi_set_statement(var, value)
-                      for (var, value) in resolved.items()
-                      ).encode('utf-8')
+    output = "".join(ssi_set_statement(var, value)
+                     for (var, value) in resolved.items()
+                     ).encode('utf-8')
     return output