bin/ansible-connection
26ec2ecf
 #!/usr/bin/env python
 
 # (c) 2016, Ansible, Inc. <support@ansible.com>
 #
 # This file is part of Ansible
 #
 # Ansible is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 # the Free Software Foundation, either version 3 of the License, or
 # (at your option) any later version.
 #
 # Ansible is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU General Public License
 # along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
 
 ########################################################
 from __future__ import (absolute_import, division, print_function)
 
3f949358
 __metaclass__ = type
26ec2ecf
 __requires__ = ['ansible']
3f949358
 
26ec2ecf
 # pkg_resources is optional: any failure while importing it is ignored and
 # the name is not referenced again in the visible code.
 # NOTE(review): presumably imported only for its import-time handling of
 # __requires__ above -- confirm.
 try:
     import pkg_resources
 except Exception:
     pass
 
 import fcntl
 import os
 import shlex
 import signal
 import socket
 import struct
 import sys
 import time
 import traceback
6fe9a5e4
 import syslog
 import datetime
c093d146
 import logging
26ec2ecf
 
 from io import BytesIO
 
 from ansible import constants as C
 from ansible.module_utils._text import to_bytes, to_native
 from ansible.module_utils.six.moves import cPickle, StringIO
 from ansible.playbook.play_context import PlayContext
 from ansible.plugins import connection_loader
 from ansible.utils.path import unfrackpath, makedirs_safe
66736730
 from ansible.errors import AnsibleConnectionFailure
e4f052c1
 from ansible.utils.display import Display
26ec2ecf
 
6fe9a5e4
 
26ec2ecf
 def do_fork():
     '''
     Does the required double fork for a daemon process. Based on
     http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/
 
     In the process that survives both forks, stdout and stderr are
     redirected to C.DEFAULT_LOG_PATH when it is set (to /dev/null
     otherwise) and stdin is closed.
 
     :returns: the pid of the first child (> 0) in the calling process and 0
         in the process created by the second fork. The intermediate process
         exits with status 0 and never returns.
 
     An OSError raised by any step terminates the current process with exit
     status 1.
     '''
     try:
         pid = os.fork()
         if pid > 0:
             # calling process: hand the child's pid back to the caller
             return pid
 
         # first child: start a new session and clear the umask
         os.setsid()
         os.umask(0)
 
         pid = os.fork()
         if pid > 0:
             # intermediate process is done, only the second child continues
             sys.exit(0)
 
         # truthiness covers both an empty string and None
         log_path = C.DEFAULT_LOG_PATH or '/dev/null'
 
         # open() replaces the Python 2 only file() builtin. The unbuffered
         # handle has to be binary because Python 3 rejects buffering=0 for
         # text mode files.
         out_file = open(log_path, 'a+')
         err_file = open(log_path, 'ab+', 0)
 
         os.dup2(out_file.fileno(), sys.stdout.fileno())
         os.dup2(err_file.fileno(), sys.stderr.fileno())
         os.close(sys.stdin.fileno())
 
         return pid
     except OSError:
         sys.exit(1)
 
 def send_data(s, data):
     '''
     Write ``data`` to the socket ``s`` as a single length-prefixed frame.
 
     The frame consists of an 8 byte unsigned length header in network byte
     order followed by the payload.
 
     :param s: object providing ``sendall()``
     :param data: payload bytes
     :returns: whatever ``s.sendall()`` returns
     '''
     header = struct.pack('!Q', len(data))
     frame = header + data
     return s.sendall(frame)
 
 def recv_data(s):
     '''
     Read one length-prefixed frame from the socket ``s``.
 
     A frame is an 8 byte unsigned length header in network byte order
     followed by that many payload bytes.
 
     :param s: object providing ``recv()``
     :returns: the payload bytes, or None when ``s.recv()`` returns no data
         before the header or the payload is complete
     '''
     def read_exactly(count):
         # keep requesting the outstanding byte count until it is satisfied
         pieces = []
         outstanding = count
         while outstanding > 0:
             piece = s.recv(outstanding)
             if not piece:
                 return None
             pieces.append(piece)
             outstanding -= len(piece)
         return b"".join(pieces)
 
     header = read_exactly(8)  # size of a packed unsigned long long
     if header is None:
         return None
 
     (payload_len,) = struct.unpack('!Q', header)
     return read_exactly(payload_len)
 
c093d146
 
26ec2ecf
 class Server():
     '''
     Daemon side of a persistent connection.
 
     Connects to the remote host through the connection plugin named by the
     play context, then listens on a UNIX domain socket and serves EXEC,
     PUT, FETCH and CONTEXT requests framed with send_data()/recv_data().
     '''
 
     def __init__(self, path, play_context):
         '''
         Connect to the remote host and start listening on the control socket.
 
         :param path: filesystem path the UNIX domain socket is bound to
         :param play_context: play context; connection, remote_addr, port,
             remote_user and timeout are read from it
         :raises AnsibleConnectionFailure: when the connection plugin does not
             report being connected after _connect()
         '''
         self.path = path
         self.play_context = play_context
 
         display.display(
             'creating new control socket for host %s:%s as user %s' %
             (play_context.remote_addr, play_context.port, play_context.remote_user),
             log_only=True
         )
 
         display.display('control socket path is %s' % path, log_only=True)
         display.display('current working directory is %s' % os.getcwd(), log_only=True)
 
         self._start_time = datetime.datetime.now()
 
         display.display("using connection plugin %s" % self.play_context.connection, log_only=True)
 
         self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
         self.conn._connect()
         if not self.conn.connected:
             # the attribute is play_context; _play_context does not exist and
             # used to turn this failure into an AttributeError
             raise AnsibleConnectionFailure('unable to connect to remote host %s' % self.play_context.remote_addr)
 
         connection_time = datetime.datetime.now() - self._start_time
         display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True)
 
         self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
         self.socket.bind(path)
         self.socket.listen(1)
 
         signal.signal(signal.SIGALRM, self.alarm_handler)
 
     def dispatch(self, obj, name, *args, **kwargs):
         '''
         Call ``obj.<name>(*args, **kwargs)`` if that attribute exists and is
         truthy.
 
         :returns: the result of the call, or None when there is nothing to call
         '''
         meth = getattr(obj, name, None)
         if meth:
             return meth(*args, **kwargs)
 
     def alarm_handler(self, signum, frame):
         '''
         SIGALRM handler: lets the connection plugin clean up, then closes the
         listening socket.
         '''
         # FIXME: this should also set internal flags for other
         #        areas of code to check, so they can terminate
         #        earlier than the socket going back to the accept
         #        call and failing there.
         #
         # hooks the connection plugin to handle any cleanup
         self.dispatch(self.conn, 'alarm_handler', signum, frame)
         self.socket.close()
 
     def run(self):
         '''
         Accept clients on the control socket and serve their requests until
         accept() fails, then close the connection and remove the socket file.
 
         Each request is one frame starting with ``EXEC: ``, ``PUT: ``,
         ``FETCH: `` or ``CONTEXT: ``. Every request except CONTEXT is answered
         with three frames: rc, stdout and stderr. rc stays 255 unless the
         operation sets it.
         '''
         try:
             while True:
                 # set the alarm, if we don't get an accept before it
                 # goes off we exit (via an exception caused by the socket
                 # getting closed while waiting on accept())
                 # FIXME: is this the best way to exit? as noted in the
                 #        alarm handler we should probably be setting a flag
                 #        to check here and in other parts of the code
                 signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
                 try:
                     (s, addr) = self.socket.accept()
                     display.display('incoming request accepted on persistent socket', log_only=True)
                     # clear the alarm
                     # FIXME: potential race condition here between the accept and
                     #        time to this call.
                     signal.alarm(0)
                 except Exception:
                     break
 
                 while True:
                     data = recv_data(s)
                     if not data:
                         break
 
                     signal.alarm(self.play_context.timeout)
 
                     rc = 255
                     try:
                         if data.startswith(b'EXEC: '):
                             display.display("socket operation is EXEC", log_only=True)
                             # maxsplit=1 so a command that itself contains
                             # the marker is not truncated
                             cmd = data.split(b'EXEC: ', 1)[1]
                             (rc, stdout, stderr) = self.conn.exec_command(cmd)
                         elif data.startswith((b'PUT: ', b'FETCH: ')):
                             (op, src, dst) = shlex.split(to_native(data))
                             stdout = stderr = ''
                             try:
                                 if op == 'FETCH:':
                                     display.display("socket operation is FETCH", log_only=True)
                                     self.conn.fetch_file(src, dst)
                                 elif op == 'PUT:':
                                     display.display("socket operation is PUT", log_only=True)
                                     self.conn.put_file(src, dst)
                                 rc = 0
                             except Exception:
                                 # rc stays 255; report the reason instead of
                                 # silently discarding it
                                 stderr = traceback.format_exc()
                                 display.display(stderr, log_only=True)
                         elif data.startswith(b'CONTEXT: '):
                             display.display("socket operation is CONTEXT", log_only=True)
                             pc_data = data.split(b'CONTEXT: ', 1)[1]
 
                             # NOTE(review): unpickling runs arbitrary code if
                             # the peer on the control socket is not trusted.
                             # The payload is bytes, so it needs BytesIO
                             # rather than a text StringIO.
                             src = BytesIO(pc_data)
                             pc_data = cPickle.load(src)
                             src.close()
 
                             pc = PlayContext()
                             pc.deserialize(pc_data)
 
                             self.dispatch(self.conn, 'update_play_context', pc)
                             # CONTEXT requests are not answered
                             continue
                         else:
                             display.display("socket operation is UNKNOWN", log_only=True)
                             stdout = ''
                             stderr = 'Invalid action specified'
                     except Exception:
                         stdout = ''
                         stderr = traceback.format_exc()
 
                     signal.alarm(0)
 
                     display.display("socket operation completed with rc %s" % rc, log_only=True)
 
                     send_data(s, to_bytes(str(rc)))
                     send_data(s, to_bytes(stdout))
                     send_data(s, to_bytes(stderr))
                 s.close()
         except Exception:
             display.display(traceback.format_exc(), log_only=True)
         finally:
             # when done, close the connection properly and cleanup
             # the socket file so it can be recreated
             end_time = datetime.datetime.now()
             delta = end_time - self._start_time
             display.display('shutting down control socket, connection was active for %s secs' % delta, log_only=True)
             # close each resource independently so a failure closing the
             # connection does not leave the listening socket open
             for closer in (self.conn.close, self.socket.close):
                 try:
                     closer()
                 except Exception:
                     display.display(traceback.format_exc(), log_only=True)
             os.remove(self.path)
 
 def main():
     '''
     Client entry point of the persistent connection helper.
 
     Reads a pickled play context from stdin (terminated by a line containing
     #END_INIT#), starts the daemon process if the control socket file does
     not exist yet, then forwards the first non-blank line read from stdin to
     the daemon and relays the returned stdout/stderr. Exits with the rc
     reported by the daemon, or 255 when the control socket cannot be used.
     '''
     try:
         # read the play context data via stdin, which means depickling it
         # FIXME: we will probably need to deserialize the connection loader
         #        here as well at some point, otherwise this won't find
         #        role- or playbook-based connection plugins
         init_lines = []
         cur_line = sys.stdin.readline()
         while cur_line.strip() != '#END_INIT#':
             if cur_line == '':
                 # readline() returns '' only at end of file
                 raise Exception("EOF found before init data was complete")
             init_lines.append(cur_line)
             cur_line = sys.stdin.readline()
         src = BytesIO(to_bytes(''.join(init_lines)))
         pc_data = cPickle.load(src)
 
         pc = PlayContext()
         pc.deserialize(pc_data)
     except Exception as e:
         # FIXME: better error message/handling/logging
         sys.stderr.write(traceback.format_exc())
         sys.exit("FAIL: %s" % e)
 
     ssh = connection_loader.get('ssh', class_only=True)
     cp_template = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user)
 
     # create the persistent connection dir if need be and create the paths
     # which we will be using later
     tmp_path = unfrackpath("$HOME/.ansible/pc")
     makedirs_safe(tmp_path)
     lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
     sf_path = unfrackpath(cp_template % dict(directory=tmp_path))
 
     # if the socket file doesn't exist, spin up the daemon process
     lock_fd = os.open(lk_path, os.O_RDWR | os.O_CREAT, 0o600)
     fcntl.lockf(lock_fd, fcntl.LOCK_EX)
     if not os.path.exists(sf_path):
         pid = do_fork()
         if pid == 0:
             # daemon process: never falls through to the client code below
             rc = 0
             try:
                 server = Server(sf_path, pc)
             except AnsibleConnectionFailure as exc:
                 display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True)
                 display.display(str(exc), log_only=True)
                 rc = 1
             except Exception:
                 display.display('failed to create control socket for host %s' % pc.remote_addr, log_only=True)
                 display.display(traceback.format_exc(), log_only=True)
                 rc = 1
             fcntl.lockf(lock_fd, fcntl.LOCK_UN)
             os.close(lock_fd)
             if rc == 0:
                 server.run()
             sys.exit(rc)
     else:
         display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True)
     fcntl.lockf(lock_fd, fcntl.LOCK_UN)
     os.close(lock_fd)
 
     # the play context is sent back into the connection so the connection
     # can handle any privilege escalation activities; built from bytes so
     # Python 3 does not embed the repr of the pickled data
     context_frame = b'CONTEXT: ' + src.getvalue()
     src.close()
 
     # now connect to the daemon process
     # FIXME: if the socket file existed but the daemonized process was killed,
     #        the connection will timeout here. Need to make this more resilient.
     rc = 0
     while rc == 0:
         data = sys.stdin.readline()
         if data == '':
             break
         if data.strip() == '':
             continue
         sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
         attempts = 1
         while True:
             try:
                 sf.connect(sf_path)
                 break
             except socket.error:
                 # FIXME: better error handling/logging/message here
                 time.sleep(C.PERSISTENT_CONNECT_INTERVAL)
                 attempts += 1
                 if attempts > C.PERSISTENT_CONNECT_RETRIES:
                     # host and user belong in the message; passing them as
                     # extra positional arguments hands them to display()'s
                     # next parameters instead of logging them
                     display.display(
                         'number of connection attempts exceeded, unable to connect to control socket for %s@%s' %
                         (pc.remote_user, pc.remote_addr),
                         log_only=True
                     )
                     display.display(
                         'persistent_connect_interval=%s, persistent_connect_retries=%s' %
                         (C.PERSISTENT_CONNECT_INTERVAL, C.PERSISTENT_CONNECT_RETRIES),
                         log_only=True
                     )
                     sys.stderr.write('failed to connect to control socket')
                     sys.exit(255)
 
         send_data(sf, context_frame)
         send_data(sf, to_bytes(data.strip()))
 
         rc_data = recv_data(sf)
         stdout = recv_data(sf)
         stderr = recv_data(sf)
         sf.close()
 
         # recv_data() returns None when the daemon closes the socket early
         if not rc_data or stdout is None or stderr is None:
             sys.stderr.write('control socket closed before a complete response was received')
             sys.exit(255)
 
         rc = int(rc_data, 10)
 
         sys.stdout.write(to_native(stdout))
         sys.stderr.write(to_native(stderr))
 
         # only the first non-blank line is processed
         break
 
     sys.exit(rc)
 
 # Script entry point.
 if __name__ == '__main__':
e4f052c1
     # module-level Display instance; Server and main() use it through the
     # global name `display`
     display = Display()
26ec2ecf
     main()