A somewhat usable and tested version.
[django-ssify.git] / ssify / variables.py
index 9ccc3f3..e969450 100644 (file)
@@ -1,3 +1,7 @@
+# -*- coding: utf-8 -*-
+# This file is part of django-ssify, licensed under GNU Affero GPLv3 or later.
+# Copyright © Fundacja Nowoczesna Polska. See README.md for more information.
+#
 """
 Utilities for defining SSI variables.
 
@@ -5,14 +9,16 @@ SSI variables are a way of providing values that need to be computed
 at request time to the prerendered templates.
 
 """
+from __future__ import unicode_literals
 from hashlib import md5
 from django import template
-from django.utils.encoding import force_unicode
+from django.utils.encoding import force_text, python_2_unicode_compatible
 from django.utils.functional import Promise
 from django.utils.safestring import mark_safe
-from .exceptions import SsiVarsDependencyCycleError, UndeclaredSsiRefError
+from .exceptions import SsiVarsDependencyCycleError
 
 
+@python_2_unicode_compatible
 class SsiVariable(object):
     """
     Represents a variable computed by a template tag with given arguments.
@@ -27,8 +33,6 @@ class SsiVariable(object):
     so the user never has to deal with it directly.
 
     """
-    ret_types = 'bool', 'int', 'unicode'
-
     def __init__(self, tagpath=None, args=None, kwargs=None, name=None):
         self.tagpath = tagpath
         self.args = list(args or [])
@@ -39,7 +43,7 @@ class SsiVariable(object):
     def name(self):
         """Variable name is a hash of its definition."""
         if self._name is None:
-            self._name = 'v' + md5(json_encode(self.definition)).hexdigest()
+            self._name = 'v' + md5(json_encode(self.definition).encode('utf-8')).hexdigest()
         return self._name
 
     def rehash(self):
@@ -71,7 +75,7 @@ class SsiVariable(object):
         return template.get_library(taglib).tags[tagname].get_value(
             request, *self.args, **self.kwargs)
 
-    def __unicode__(self):
+    def __str__(self):
         return mark_safe("<!--#echo var='%s' encoding='none'-->" % self.name)
 
     def as_var(self):
@@ -91,6 +95,8 @@ class SsiExpect(object):
     """This class says: I want the real value of this variable here."""
     def __init__(self, name):
         self.name = name
+    def __repr__(self):
+        return "SsiExpect(%s)" % (self.name,)
 
 
 def ssi_expect(var, type_):
@@ -153,7 +159,7 @@ def ssi_set_statement(var, value):
         value = ''
     return "<!--#set var='%s' value='%s'-->" % (
         var,
-        force_unicode(value).replace(u'\\', u'\\\\').replace(u"'", u"\\'"))
+        force_text(value).replace('\\', '\\\\').replace("'", "\\'"))
 
 
 def provide_vars(request, ssi_vars):
@@ -162,50 +168,78 @@ def provide_vars(request, ssi_vars):
 
-    The main purpose of this function is to by called by SsifyMiddleware.
+    The main purpose of this function is to be called by SsifyMiddleware.
     """
+    def resolve_expects(var):
+        """
+        Replace SsiExpect args of `var` with resolved values, in place.
+
+        Recurses into nested SsiVariable args; returns True if `var` was
+        rehashed. Raises KeyError while an expected value is unresolved;
+        `var.hash_dirty` survives the error, so a later retry still rehashes.
+        """
+        if not hasattr(var, 'hash_dirty'):
+            var.hash_dirty = False
+
+        for i, arg in enumerate(var.args):
+            if isinstance(arg, SsiExpect):
+                var.args[i] = resolved[arg.name]
+                var.hash_dirty = True
+        for k, arg in var.kwargs.items():
+            if isinstance(arg, SsiExpect):
+                var.kwargs[k] = resolved[arg.name]
+                var.hash_dirty = True
+
+        for arg in var.args + list(var.kwargs.values()):
+            if isinstance(arg, SsiVariable):
+                var.hash_dirty = resolve_expects(arg) or var.hash_dirty
+
+        hash_dirty = var.hash_dirty
+        if var.hash_dirty:
+            # Rehash after calculating the SsiExpects with real
+            # values, because that's what the included views expect.
+            var.rehash()
+            var.hash_dirty = False
+
+        return hash_dirty
+
+    def resolve_args(var):
+        """
+        Return a new SsiVariable with its SsiVariable args replaced by values.
+
+        Raises KeyError if any of those variables is not resolved yet.
+        """
+        kwargs = {}
+        for k, arg in var.kwargs.items():
+            kwargs[k] = resolved[arg.name] if isinstance(arg, SsiVariable) else arg
+        new_var = SsiVariable(var.tagpath,
+            [resolved[arg.name] if isinstance(arg, SsiVariable) else arg for arg in var.args],
+            kwargs)
+        return new_var
+
     resolved = {}
-    queue = ssi_vars.items()
+    queue = list(ssi_vars.values())
+
     unresolved_streak = 0
     while queue:
-        var_name, var = queue.pop(0)
-        hash_dirty = False
-        new_name = var_name
-
+        var = queue.pop(0)
         try:
-            for i, arg in enumerate(var.args):
-                if isinstance(arg, SsiExpect):
-                    var.args[i] = resolved[arg.name]
-                    hash_dirty = True
-            for k, arg in var.kwargs.items():
-                if isinstance(arg, SsiExpect):
-                    var.args[k] = resolved[arg.name]
-                    hash_dirty = True
-
-            if hash_dirty:
-                # Rehash after calculating the SsiExpects with real
-                # values, because that's what the included views expect.
-                new_name = var.rehash()
-
-            for i, arg in enumerate(var.args):
-                if isinstance(arg, SsiVariable):
-                    var.args[i] = resolved[arg.name]
-            for k, arg in var.kwargs.items():
-                if isinstance(arg, SsiVariable):
-                    var.args[k] = resolved[arg.name]
-
-        except KeyError:
-            queue.append((var_name, var))
+            resolve_expects(var)
+            rv = resolve_args(var)
+        except KeyError:
+            # A dependency is not resolved yet; retry this variable later.
+            queue.append(var)
             unresolved_streak += 1
-            if unresolved_streak == len(queue):
-                if arg.name in ssi_vars:
-                    raise SsiVarsDependencyCycleError(queue)
-                else:
-                    raise UndeclaredSsiRefError(request, var, arg.name)
+            # Each queued variable has now failed once since the last
+            # success and `resolved` has not changed, so retrying cannot
+            # help. NOTE: this also fires for references to variables
+            # that are never provided, not only for real cycles.
+            if unresolved_streak >= len(queue):
+                raise SsiVarsDependencyCycleError(request, queue, resolved)
             continue
 
-        resolved[new_name] = var.get_value(request)
+        resolved[var.name] = rv.get_value(request)
         unresolved_streak = 0
 
-    output = u"".join(ssi_set_statement(var, value)
+    output = "".join(ssi_set_statement(var, value)
                       for (var, value) in resolved.items()
                       ).encode('utf-8')
     return output