bin/ansible-connection
26ec2ecf
 #!/usr/bin/env python
9c0275a8
 # Copyright: (c) 2017, Ansible Project
 # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
26ec2ecf
 from __future__ import (absolute_import, division, print_function)
 
3f949358
 __metaclass__ = type
26ec2ecf
 __requires__ = ['ansible']
3f949358
 
26ec2ecf
 try:
     import pkg_resources
 except Exception:
     pass
 
 import fcntl
f2211058
 import hashlib
26ec2ecf
 import os
 import signal
 import socket
 import sys
3f3101df
 import time
26ec2ecf
 import traceback
62159228
 import errno
9c0275a8
 import json
26ec2ecf
 
7ce9968c
 from contextlib import contextmanager
 
26ec2ecf
 from ansible import constants as C
d03b9edd
 from ansible.module_utils._text import to_bytes, to_text
d834412e
 from ansible.module_utils.six import PY3
90cd87f9
 from ansible.module_utils.six.moves import cPickle, StringIO
a1517234
 from ansible.module_utils.connection import Connection, ConnectionError, send_data, recv_data
9c0275a8
 from ansible.module_utils.service import fork_process
26ec2ecf
 from ansible.playbook.play_context import PlayContext
f9213694
 from ansible.plugins.loader import connection_loader
26ec2ecf
 from ansible.utils.path import unfrackpath, makedirs_safe
e20ed8bc
 from ansible.utils.display import Display
9c0275a8
 from ansible.utils.jsonrpc import JsonRpcServer
26ec2ecf
 
6fe9a5e4
 
f2211058
 def read_stream(byte_stream):
     """Read one length-prefixed, checksummed payload from a byte stream.

     Framing consumed from ``byte_stream``: a line holding the payload length
     as a decimal integer, then exactly that many payload bytes, then a line
     holding the hex SHA-1 digest of the payload.

     :arg byte_stream: binary file-like object providing ``readline`` and ``read``.
     :returns: the payload bytes, with every two-byte sequence backslash + 'r'
         turned back into a literal carriage return.
     :raises ValueError: if the length line is not an integer.
     :raises Exception: if the stream ends before the payload is complete, or
         if the digest line does not match the payload.
     """
     expected_len = int(byte_stream.readline().strip())
     payload = byte_stream.read(expected_len)
     if expected_len > len(payload):
         raise Exception("EOF found before data was complete")

     expected_digest = to_text(byte_stream.readline().strip())
     actual_digest = hashlib.sha1(payload).hexdigest()
     if expected_digest != actual_digest:
         raise Exception("Read {0} bytes, but data did not match checksum".format(expected_len))

     # undo the escaping of loose carriage returns in the payload
     return payload.replace(br'\r', b'\r')
 
 
7ce9968c
 @contextmanager
 def file_lock(lock_path):
     """
     Hold an exclusive ``fcntl.lockf`` lock on ``lock_path`` for the duration
     of a ``with file_lock(...)`` block.

     The lock file is created with mode 0600 if it does not exist, and
     acquiring the lock blocks until it is available. The lock is released and
     the descriptor closed when the block exits, including when it exits via an
     exception, so a failure inside the block can neither leave the lock held
     nor leak the file descriptor.

     :arg lock_path: path of the file used as the lock.
     """
     lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600)
     try:
         fcntl.lockf(lock_fd, fcntl.LOCK_EX)
         try:
             yield
         finally:
             # runs on normal exit and on any exception raised in the with-body
             fcntl.lockf(lock_fd, fcntl.LOCK_UN)
     finally:
         # also closes the descriptor if acquiring the lock itself failed
         os.close(lock_fd)
 
 
9c0275a8
 class ConnectionProcess(object):
     '''
     The connection process wraps around a Connection object that manages
     the connection to a remote device that persists over the playbook run.

     It loads and connects the connection plugin, serves JSON-RPC requests for
     it on a local domain socket at ``socket_path``, and tears both down in
     ``shutdown()``.
     '''
     def __init__(self, fd, play_context, socket_path, original_path, ansible_playbook_pid=None):
         self.play_context = play_context
         self.socket_path = socket_path
         # working directory of the launching process; used by start() to
         # resolve a relative private key path
         self.original_path = original_path

         # writable file object; start() reports its outcome here as JSON
         self.fd = fd
         # formatted traceback of the last error seen by run(), if any
         self.exception = None

         self.srv = JsonRpcServer()
         self.sock = None

         self.connection = None
         self._ansible_playbook_pid = ansible_playbook_pid

     def start(self, variables):
         """Load and connect the connection plugin, then open the socket listener.

         The outcome is written to ``self.fd`` as a JSON object with a
         ``messages`` list and, on failure, ``error`` and ``exception`` keys;
         ``self.fd`` is closed afterwards. Exceptions raised during setup are
         caught and reported this way instead of propagating.

         :arg variables: passed to the plugin's ``set_options(var_options=...)``.
         """
         messages = list()
         result = {}
         try:
             messages.append('control socket path is %s' % self.socket_path)

             # If this is a relative path (~ gets expanded later) then plug the
             # key's path on to the directory we originally came from, so we can
             # find it now that our cwd is /
             if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
                 self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
             self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null',
                                                     ansible_playbook_pid=self._ansible_playbook_pid)
             self.connection.set_options(var_options=variables)
             self.connection._connect()

             self.connection._socket_path = self.socket_path
             self.srv.register(self.connection)
             messages.extend(sys.stdout.getvalue().splitlines())
             messages.append('connection to remote device started successfully')

             self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
             self.sock.bind(self.socket_path)
             self.sock.listen(1)
             messages.append('local domain socket listeners started successfully')
         except Exception as exc:
             result['error'] = to_text(exc)
             result['exception'] = traceback.format_exc()
         finally:
             result['messages'] = messages
             self.fd.write(json.dumps(result))
             self.fd.close()

     def run(self):
         """Serve JSON-RPC requests on the local domain socket until disconnected.

         Waiting for a client is bounded by ``persistent_connect_timeout`` and
         each request by ``persistent_command_timeout``; both are enforced with
         SIGALRM. Any exception ends the loop (EINTR is not recorded as an
         error), after which the process shuts down.
         """
         try:
             while self.connection.connected:
                 signal.signal(signal.SIGALRM, self.connect_timeout)
                 signal.signal(signal.SIGTERM, self.handler)
                 signal.alarm(self.connection.get_option('persistent_connect_timeout'))

                 self.exception = None
                 (s, addr) = self.sock.accept()
                 signal.alarm(0)

                 signal.signal(signal.SIGALRM, self.command_timeout)
                 try:
                     while True:
                         data = recv_data(s)
                         if not data:
                             break

                         signal.alarm(self.connection.get_option('persistent_command_timeout'))
                         resp = self.srv.handle_request(data)
                         signal.alarm(0)

                         send_data(s, to_bytes(resp))
                 finally:
                     # close the client socket even when handling a request raised
                     s.close()

         except Exception as e:
             # socket.accept() will raise EINTR if the socket.close() is called
             if hasattr(e, 'errno'):
                 if e.errno != errno.EINTR:
                     self.exception = traceback.format_exc()
             else:
                 self.exception = traceback.format_exc()

         finally:
             # cancel any pending alarm so a late timeout cannot raise here and skip shutdown()
             signal.alarm(0)

             # allow time for any exception msg send over socket to receive at other end before shutting down
             time.sleep(0.1)

             # when done, close the connection properly and cleanup the socket file so it can be recreated
             self.shutdown()

     def connect_timeout(self, signum, frame):
         """SIGALRM handler while waiting for a client: log and raise."""
         msg = 'persistent connection idle timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and ' \
               'Troubleshooting Guide.' % self.connection.get_option('persistent_connect_timeout')
         display.display(msg, log_only=True)
         raise Exception(msg)

     def command_timeout(self, signum, frame):
         """SIGALRM handler while a request is being handled: log and raise."""
         msg = 'command timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and Troubleshooting Guide.'\
               % self.connection.get_option('persistent_command_timeout')
         display.display(msg, log_only=True)
         raise Exception(msg)

     def handler(self, signum, frame):
         """SIGTERM handler: log and raise so run() unwinds into shutdown()."""
         msg = 'signal handler called with signal %s.' % signum
         display.display(msg, log_only=True)
         raise Exception(msg)

     def shutdown(self):
         """ Shuts down the local domain socket and the remote connection.

         Best effort: each close is attempted independently and failures are
         logged rather than raised. The socket file is removed if present so it
         can be recreated, and the connection object is marked disconnected.
         """
         try:
             if self.sock:
                 self.sock.close()
         except Exception as exc:
             display.display('error closing local domain socket: %s' % to_text(exc), log_only=True)

         try:
             # closed even if the socket file is missing (e.g. bind() failed
             # after the remote connection was already established)
             if self.connection:
                 self.connection.close()
         except Exception as exc:
             display.display('error closing persistent connection: %s' % to_text(exc), log_only=True)
         finally:
             if os.path.exists(self.socket_path):
                 os.remove(self.socket_path)
             # the connection is None if connection_loader.get() never succeeded
             if self.connection is not None:
                 setattr(self.connection, '_socket_path', None)
                 setattr(self.connection, '_connected', False)
         display.display('shutdown complete', log_only=True)
 
8e0b5800
 
26ec2ecf
 def main():
     """ Called to initiate the connect to the remote device

     Reads two length-prefixed pickled payloads from stdin (task variables,
     then the serialized PlayContext). If no local domain socket exists yet for
     the computed control path, it forks a ConnectionProcess that owns a new
     socket. Otherwise it reuses the existing socket and pushes the play context
     to it. A JSON result (``messages``, ``socket_path`` and, on failure,
     ``error``/``exception``) is written to stdout on success or to stderr on
     failure, and the process exits with 0 or 1 accordingly.
     """
     rc = 0
     result = {}
     messages = list()
     socket_path = None

     # Need stdin as a byte stream
     if PY3:
         stdin = sys.stdin.buffer
     else:
         stdin = sys.stdin

     # Note: update the below log capture code after Display.display() is refactored.
     saved_stdout = sys.stdout
     sys.stdout = StringIO()

     try:
         # read the play context data via stdin, which means depickling it
         # NOTE(security): unpickling can execute arbitrary code; stdin must only
         # ever come from the trusted ansible process that launched us.
         vars_data = read_stream(stdin)
         init_data = read_stream(stdin)

         if PY3:
             pc_data = cPickle.loads(init_data, encoding='bytes')
             variables = cPickle.loads(vars_data, encoding='bytes')
         else:
             pc_data = cPickle.loads(init_data)
             variables = cPickle.loads(vars_data)

         play_context = PlayContext()
         play_context.deserialize(pc_data)
         display.verbosity = play_context.verbosity

     except Exception as e:
         rc = 1
         result.update({
             'error': to_text(e),
             'exception': traceback.format_exc()
         })

     if rc == 0:
         ssh = connection_loader.get('ssh', class_only=True)
         ansible_playbook_pid = sys.argv[1]
         cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection, ansible_playbook_pid)

         # create the persistent connection dir if need be and create the paths
         # which we will be using later
         tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
         makedirs_safe(tmp_path)

         lock_path = unfrackpath("%s/.ansible_pc_lock_%s" % (tmp_path, play_context.remote_addr))
         socket_path = unfrackpath(cp % dict(directory=tmp_path))

         with file_lock(lock_path):
             if not os.path.exists(socket_path):
                 messages.append('local domain socket does not exist, starting it')
                 original_path = os.getcwd()
                 r, w = os.pipe()
                 pid = fork_process()

                 if pid == 0:
                     # child: start the connection, report over the pipe, then serve requests
                     process = None
                     try:
                         os.close(r)
                         wfd = os.fdopen(w, 'w')
                         process = ConnectionProcess(wfd, play_context, socket_path, original_path, ansible_playbook_pid)
                         process.start(variables)
                     except Exception:
                         messages.append(traceback.format_exc())
                         rc = 1

                     if rc == 0:
                         process.run()
                     elif process is not None:
                         # process stays None when setup failed before it was constructed
                         process.shutdown()

                     sys.exit(rc)

                 else:
                     # parent: read the JSON status the child writes in ConnectionProcess.start()
                     os.close(w)
                     with os.fdopen(r, 'r') as rfd:
                         raw_status = rfd.read()
                     try:
                         data = json.loads(raw_status)
                     except ValueError as exc:
                         # empty or partial output: the child ended before start() wrote its result
                         result.update({
                             'error': 'unable to read start-up status from persistent connection process: %s' % to_text(exc),
                             'exception': traceback.format_exc()
                         })
                     else:
                         messages.extend(data.pop('messages', []))
                         result.update(data)

             else:
                 messages.append('found existing local domain socket, using it!')
                 conn = Connection(socket_path)
                 conn.set_options(var_options=variables)
                 pc_data = to_text(init_data)
                 try:
                     messages.extend(conn.update_play_context(pc_data))
                 except Exception as exc:
                     # Only network_cli has update_play context, so missing this is
                     # not fatal e.g. netconf
                     if isinstance(exc, ConnectionError) and getattr(exc, 'code', None) == -32601:
                         pass
                     else:
                         result.update({
                             'error': to_text(exc),
                             'exception': traceback.format_exc()
                         })

     messages.append(sys.stdout.getvalue())
     result.update({
         'messages': messages,
         'socket_path': socket_path
     })

     sys.stdout = saved_stdout
     if 'exception' in result:
         rc = 1
         sys.stderr.write(json.dumps(result))
     else:
         rc = 0
         sys.stdout.write(json.dumps(result))

     sys.exit(rc)
 
527fc492
 
26ec2ecf
 if __name__ == '__main__':
e20ed8bc
     # Script entry point. `display` is deliberately bound as a module-level
     # global here, before main() runs: main() and the ConnectionProcess
     # methods refer to it by that global name.
     display = Display()
26ec2ecf
     main()