From f2d1e2da50388b2cd0715cefd56183ab1f9e5491 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20Beno=C3=AEt?= <tseeker@nocternity.net>
Date: Sun, 18 Sep 2022 10:06:19 +0200
Subject: [PATCH] Implemented vars clause

  * Can be used on any instruction
  * Variables are evaluated before condition is checked
---
 example/01-test-reconstructed.yml  |  2 +-
 inventory_plugins/reconstructed.py | 77 ++++++++++++++++++++++++++----
 2 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/example/01-test-reconstructed.yml b/example/01-test-reconstructed.yml
index 077d24f..27b7ee1 100644
--- a/example/01-test-reconstructed.yml
+++ b/example/01-test-reconstructed.yml
@@ -95,7 +95,7 @@ instructions:
   # Component group. We add the host directly if there is no subcomponent.
   - when: inv__component is defined
     action: block
-    locals:
+    vars:
       comp_group: "svcm_{{ inv__service }}_{{ inv__component }}"
     block:
     - action: create_group
diff --git a/inventory_plugins/reconstructed.py b/inventory_plugins/reconstructed.py
index 5d4f9c2..3642a28 100644
--- a/inventory_plugins/reconstructed.py
+++ b/inventory_plugins/reconstructed.py
@@ -93,7 +93,7 @@ DOCUMENTATION = """
         default: host
 """
 
-INSTR_COMMON_FIELDS = ("when", "loop", "loop_var", "action", "run_once")
+INSTR_COMMON_FIELDS = ("action", "loop", "loop_var", "run_once", "vars", "when")
 """Fields that may be present on all instructions."""
 
 INSTR_OWN_FIELDS = {
@@ -229,11 +229,13 @@ class RcInstruction(abc.ABC):
         self._inventory = inventory
         self._templar = templar
         self._display = display
+        self._action = action
         self._condition = None
+        self._executed_once = None
         self._loop = None
         self._loop_var = None
-        self._action = action
-        self._executed_once = None
+        self._vars = None
+        self._save = None
 
     def __repr__(self):
         """Builds a compact debugging representation of the instruction, \
@@ -338,12 +340,52 @@ class RcInstruction(abc.ABC):
             raise AnsibleParserError(
                 "%s: 'loop_var' clause found without 'loop'" % (self._action,)
             )
+        # Extract local variables
+        self._vars = self.parse_vars(record)
+        # Cache the list of variables to save before execution
+        save = list(self._vars.keys())
+        if self._loop is not None:
+            save.append(self._loop_var)
+        self._save = tuple(save)
         # Handle instructions that may only be executed once
         if record.get("run_once", False):
             self._executed_once = False
         # Process action-specific fields
         self.parse_action(record)
 
+    def parse_vars(self, record):
+        """Parse local variable definitions from the record.
+
+        This method checks for a ``vars`` section in the YAML data, and extracts
+        it if it exists.
+
+        Args:
+            record: the YAML data for the instruction
+
+        Returns:
+            a dictionnary that contains the variable definitions
+
+        Raises:
+            AnsibleParserError: when the ``vars`` entry is invalid or contains \
+                    invalid definitions
+        """
+        if "vars" not in record:
+            return {}
+        if not isinstance(record["vars"], dict):
+            raise AnsibleParserError(
+                "%s: 'vars' should be a dictionnary" % (self._action,)
+            )
+        for k, v in record["vars"].items():
+            if not isinstance(k, string_types):
+                raise AnsibleParserError(
+                    "%s: vars identifiers must be strings" % (self._action,)
+                )
+            if not isidentifier(k):
+                raise AnsibleParserError(
+                    "%s: '%s' is not a valid identifier" % (self._action, k)
+                )
+        return record["vars"]
+
     def parse_group_name(self, record, name):
         """Parse a field containing the name of a group, or a template.
 
@@ -409,12 +451,13 @@ class RcInstruction(abc.ABC):
             return True
         if self._executed_once is False:
             self._executed_once = True
-        if self._loop is None:
-            self._display.vvvv("%s : running action %s" % (host_name, self._action))
-            return self.run_iteration(host_name, variables)
-        # Save previous loop variable state
-        variables._script_stack_push([self._loop_var])
+        # Save previous loop and local variables state
+        variables._script_stack_push(self._save)
         try:
+            # Instructions without loops
+            if self._loop is None:
+                self._display.vvvv("%s : running action %s" % (host_name, self._action))
+                return self.run_iteration(host_name, variables)
             # Loop over all values
             for value in self.evaluate_loop(host_name, variables):
                 self._display.vvvv(
@@ -440,6 +483,7 @@ class RcInstruction(abc.ABC):
             ``True`` if execution must continue, ``False`` if it must be
             interrupted
         """
+        self.compute_locals(host_name, variables)
         if self.evaluate_condition(host_name, variables):
             rv = self.execute_action(host_name, variables)
             if not rv:
@@ -466,7 +510,6 @@ class RcInstruction(abc.ABC):
         if self._condition is None:
             return True
         t = self._templar
-        t.available_variables = variables
         template = "%s%s%s" % (
             t.environment.variable_start_string,
             self._condition,
@@ -479,6 +522,22 @@ class RcInstruction(abc.ABC):
         )
         return rv
 
+    def compute_locals(self, host_name, variables):
+        """Compute local variables.
+
+        This method iterates through all local variable definitions and runs
+        them through the templar.
+
+        Args:
+            host_name: the name of the host the instruction is being executed for
+            variables: the variable storage instance
+        """
+        self._templar.available_variables = variables
+        for key, value in self._vars.items():
+            result = self._templar.template(value)
+            variables[key] = result
+            self._display.vvvv("- set local variable %s to %s" % (key, result))
+
     def evaluate_loop(self, host_name, variables):
         """Evaluate the values to iterate over when a ``loop`` is defined.