Browse code

Prevent data being truncated over persistent connection socket (#43885)

* Change how data is sent to the persistent connection socket.

We can't rely on readline(), so send the size of the data first. We can
then read that many bytes from the stream on the recieving end.

* Set pty to noncanonical mode before sending

* Now that we send data length, we don't need a sentinel anymore

* Copy socket changes to persistent, too

* Use os.write instead of fdopen()ing and using that.

* Follow pickle with sha1sum of pickle

* Swap order of vars and init being passed to ansible-connection

(cherry picked from commit f2211058826944a4e8e9e9b678d26615a29977b3)

Nathaniel Case authored on 2018/08/10 22:26:58
Showing 4 changed files
... ...
@@ -12,6 +12,7 @@ except Exception:
12 12
     pass
13 13
 
14 14
 import fcntl
15
+import hashlib
15 16
 import os
16 17
 import signal
17 18
 import socket
... ...
@@ -35,6 +36,23 @@ from ansible.utils.display import Display
35 35
 from ansible.utils.jsonrpc import JsonRpcServer
36 36
 
37 37
 
38
+def read_stream(byte_stream):
39
+    size = int(byte_stream.readline().strip())
40
+
41
+    data = byte_stream.read(size)
42
+    if len(data) < size:
43
+        raise Exception("EOF found before data was complete")
44
+
45
+    data_hash = to_text(byte_stream.readline().strip())
46
+    if data_hash != hashlib.sha1(data).hexdigest():
47
+        raise Exception("Read {0} bytes, but data did not match checksum".format(size))
48
+
49
+    # restore escaped loose \r characters
50
+    data = data.replace(br'\r', b'\r')
51
+
52
+    return data
53
+
54
+
38 55
 @contextmanager
39 56
 def file_lock(lock_path):
40 57
     """
... ...
@@ -192,25 +210,8 @@ def main():
192 192
 
193 193
     try:
194 194
         # read the play context data via stdin, which means depickling it
195
-        cur_line = stdin.readline()
196
-        init_data = b''
197
-
198
-        while cur_line.strip() != b'#END_INIT#':
199
-            if cur_line == b'':
200
-                raise Exception("EOF found before init data was complete")
201
-            init_data += cur_line
202
-            cur_line = stdin.readline()
203
-
204
-        cur_line = stdin.readline()
205
-        vars_data = b''
206
-
207
-        while cur_line.strip() != b'#END_VARS#':
208
-            if cur_line == b'':
209
-                raise Exception("EOF found before vars data was complete")
210
-            vars_data += cur_line
211
-            cur_line = stdin.readline()
212
-        # restore escaped loose \r characters
213
-        vars_data = vars_data.replace(br'\r', b'\r')
195
+        vars_data = read_stream(stdin)
196
+        init_data = read_stream(stdin)
214 197
 
215 198
         if PY3:
216 199
             pc_data = cPickle.loads(init_data, encoding='bytes')
... ...
@@ -10,14 +10,15 @@ import time
10 10
 import json
11 11
 import subprocess
12 12
 import sys
13
+import termios
13 14
 import traceback
14 15
 
15 16
 from ansible import constants as C
16 17
 from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip
17 18
 from ansible.executor.task_result import TaskResult
18 19
 from ansible.module_utils.six import iteritems, string_types, binary_type
19
-from ansible.module_utils.six.moves import cPickle
20 20
 from ansible.module_utils._text import to_text, to_native
21
+from ansible.module_utils.connection import write_to_file_descriptor
21 22
 from ansible.playbook.conditional import Conditional
22 23
 from ansible.playbook.task import Task
23 24
 from ansible.template import Templar
... ...
@@ -915,28 +916,24 @@ class TaskExecutor:
915 915
             [python, find_file_in_path('ansible-connection'), to_text(os.getppid())],
916 916
             stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE
917 917
         )
918
-        stdin = os.fdopen(master, 'wb', 0)
919 918
         os.close(slave)
920 919
 
921
-        # Need to force a protocol that is compatible with both py2 and py3.
922
-        # That would be protocol=2 or less.
923
-        # Also need to force a protocol that excludes certain control chars as
924
-        # stdin in this case is a pty and control chars will cause problems.
925
-        # that means only protocol=0 will work.
926
-        src = cPickle.dumps(self._play_context.serialize(), protocol=0)
927
-        stdin.write(src)
928
-        stdin.write(b'\n#END_INIT#\n')
929
-
930
-        src = cPickle.dumps(variables, protocol=0)
931
-        # remaining \r fail to round-trip the socket
932
-        src = src.replace(b'\r', br'\r')
933
-        stdin.write(src)
934
-        stdin.write(b'\n#END_VARS#\n')
935
-
936
-        stdin.flush()
937
-
938
-        (stdout, stderr) = p.communicate()
939
-        stdin.close()
920
+        # We need to set the pty into noncanonical mode. This ensures that we
921
+        # can receive lines longer than 4095 characters (plus newline) without
922
+        # truncating.
923
+        old = termios.tcgetattr(master)
924
+        new = termios.tcgetattr(master)
925
+        new[3] = new[3] & ~termios.ICANON
926
+
927
+        try:
928
+            termios.tcsetattr(master, termios.TCSANOW, new)
929
+            write_to_file_descriptor(master, variables)
930
+            write_to_file_descriptor(master, self._play_context.serialize())
931
+
932
+            (stdout, stderr) = p.communicate()
933
+        finally:
934
+            termios.tcsetattr(master, termios.TCSANOW, old)
935
+        os.close(master)
940 936
 
941 937
         if p.returncode == 0:
942 938
             result = json.loads(to_text(stdout, errors='surrogate_then_replace'))
... ...
@@ -27,6 +27,7 @@
27 27
 # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 28
 
29 29
 import os
30
+import hashlib
30 31
 import json
31 32
 import socket
32 33
 import struct
... ...
@@ -36,6 +37,30 @@ import uuid
36 36
 from functools import partial
37 37
 from ansible.module_utils._text import to_bytes, to_text
38 38
 from ansible.module_utils.six import iteritems
39
+from ansible.module_utils.six.moves import cPickle
40
+
41
+
42
+def write_to_file_descriptor(fd, obj):
43
+    """Handles making sure all data is properly written to file descriptor fd.
44
+
45
+    In particular, that data is encoded in a character stream-friendly way and
46
+    that all data gets written before returning.
47
+    """
48
+    # Need to force a protocol that is compatible with both py2 and py3.
49
+    # That would be protocol=2 or less.
50
+    # Also need to force a protocol that excludes certain control chars as
51
+    # stdin in this case is a pty and control chars will cause problems.
52
+    # that means only protocol=0 will work.
53
+    src = cPickle.dumps(obj, protocol=0)
54
+
55
+    # raw \r characters will not survive pty round-trip
56
+    # They should be rehydrated on the receiving end
57
+    src = src.replace(b'\r', br'\r')
58
+    data_hash = to_bytes(hashlib.sha1(src).hexdigest())
59
+
60
+    os.write(fd, b'%d\n' % len(src))
61
+    os.write(fd, src)
62
+    os.write(fd, b'%s\n' % data_hash)
39 63
 
40 64
 
41 65
 def send_data(s, data):
... ...
@@ -32,12 +32,12 @@ import pty
32 32
 import json
33 33
 import subprocess
34 34
 import sys
35
+import termios
35 36
 
36 37
 from ansible import constants as C
37 38
 from ansible.plugins.connection import ConnectionBase
38 39
 from ansible.module_utils._text import to_text
39
-from ansible.module_utils.six.moves import cPickle
40
-from ansible.module_utils.connection import Connection as SocketConnection
40
+from ansible.module_utils.connection import Connection as SocketConnection, write_to_file_descriptor
41 41
 from ansible.errors import AnsibleError
42 42
 
43 43
 try:
... ...
@@ -107,26 +107,24 @@ class Connection(ConnectionBase):
107 107
             [python, find_file_in_path('ansible-connection'), to_text(os.getppid())],
108 108
             stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE
109 109
         )
110
-        stdin = os.fdopen(master, 'wb', 0)
111 110
         os.close(slave)
112 111
 
113
-        # Need to force a protocol that is compatible with both py2 and py3.
114
-        # That would be protocol=2 or less.
115
-        # Also need to force a protocol that excludes certain control chars as
116
-        # stdin in this case is a pty and control chars will cause problems.
117
-        # that means only protocol=0 will work.
118
-        src = cPickle.dumps(self._play_context.serialize(), protocol=0)
119
-        stdin.write(src)
120
-        stdin.write(b'\n#END_INIT#\n')
121
-
122
-        src = cPickle.dumps({}, protocol=0)
123
-        stdin.write(src)
124
-        stdin.write(b'\n#END_VARS#\n')
125
-
126
-        stdin.flush()
127
-
128
-        (stdout, stderr) = p.communicate()
129
-        stdin.close()
112
+        # We need to set the pty into noncanonical mode. This ensures that we
113
+        # can receive lines longer than 4095 characters (plus newline) without
114
+        # truncating.
115
+        old = termios.tcgetattr(master)
116
+        new = termios.tcgetattr(master)
117
+        new[3] = new[3] & ~termios.ICANON
118
+
119
+        try:
120
+            termios.tcsetattr(master, termios.TCSANOW, new)
121
+            write_to_file_descriptor(master, {'ansible_command_timeout': self.get_option('persistent_command_timeout')})
122
+            write_to_file_descriptor(master, self._play_context.serialize())
123
+
124
+            (stdout, stderr) = p.communicate()
125
+        finally:
126
+            termios.tcsetattr(master, termios.TCSANOW, old)
127
+        os.close(master)
130 128
 
131 129
         if p.returncode == 0:
132 130
             result = json.loads(to_text(stdout, errors='surrogate_then_replace'))