only use the slice's public IP if NAT is enabled
[mpi-image.git] / scripts / init.py
1 #!/usr/bin/python
2
3 import sys
4 import os
5 import signal
6 import socket
7 import struct
8 import fcntl
9 import pdb
10 import getpass
11
# Maximum number of attempts the main loop makes to reach vsys.
INSANE = 500
# Seconds fetch_output_nofail waits on its child before retrying.
BLOCKING_DELAY = 30

# Every helper file is downloaded from this location.
HTML_SOURCE = "www.vicci.org/files/mpi/"
MPI_COPY_URL = "%smpi-copy.py" % HTML_SOURCE
GENERATE_HOSTFILE_URL = "%sgenerate-hostfile.py" % HTML_SOURCE
SSHD_CONFIG_URL = "%ssshd_config" % HTML_SOURCE
ORTED_URL = "%sorted" % HTML_SOURCE
BIND_PUBLIC_URL = "%sbind_public.so" % HTML_SOURCE
MPIRUN_URL = "%smpirun" % HTML_SOURCE
22
def alarm(signum, junk):
    """No-op SIGALRM handler used by the vsys watchdog.

    With a Python-level handler installed, SIGALRM no longer terminates the
    process (the default action). Under Python 2 it instead interrupts the
    blocking ``os.waitpid`` in ``fetch_output_nofail``, which then raises
    ``OSError``.

    :param signum: signal number (ignored).
    :param junk: interrupted stack frame (ignored).
    """
    return None
26
def fetch_output(infd, outfile):
    """Trigger a vsys request and return everything it writes back.

    Writes the single byte ``1`` to *infd* (the caller opens the vsys
    ``.in`` endpoint for writing). It then reads *outfile* (the ``.out``
    endpoint) until EOF.

    :param infd: raw file descriptor open for writing.
    :param outfile: file object open for reading.
    :returns: the concatenated output as a string ('' if nothing was sent).
    :raises IOError/OSError: re-raised unchanged after printing a diagnostic.
    """
    chunks = []
    try:
        # b'1' is the same byte string as '1' on Python 2 and also works on 3.
        os.write(infd, b'1')
        while True:
            batch = outfile.read()
            if not batch:
                break
            chunks.append(batch)
    except (IOError, OSError):
        # os.write raises OSError (not IOError) on Python 2, so catch both.
        print("IO Error Accessing Vsys")
        # A bare raise keeps the original traceback; ``raise e`` loses it on
        # Python 2.
        raise
    return ''.join(chunks)
43
def fetch_output_nofail(inpath, outpath):
    """Query vsys from a forked child while the parent acts as a watchdog.

    The process forks. The child opens the vsys endpoints, performs the
    request with ``fetch_output`` and returns the result. It is therefore the
    *child* that goes on to run the rest of the script. The parent arms a
    ``BLOCKING_DELAY``-second alarm and waits for the child:

    * if the child finishes, the parent exits with status 1;
    * if the wait fails, the stuck child is killed and reaped, and ``None``
      is returned so the caller can retry. Under Python 2, the alarm makes
      ``os.waitpid`` raise ``OSError``; the caller installs the no-op
      SIGALRM handler that makes this happen.

    :param inpath: path of the vsys ``.in`` endpoint, opened for writing.
    :param outpath: path of the vsys ``.out`` endpoint, opened for reading.
    :returns: the vsys output (in the child), or ``None`` (in the parent,
        after a timeout).
    """
    pid = os.fork()

    if pid == 0:
        # Open the read side non-blocking so this open cannot hang
        # (presumably the vsys endpoints are FIFOs). Then switch it back to
        # blocking so the read in fetch_output waits for the reply.
        outfd = os.open(outpath, os.O_RDONLY | os.O_NONBLOCK)
        infd = os.open(inpath, os.O_WRONLY)

        flags = fcntl.fcntl(outfd, fcntl.F_GETFL)
        fcntl.fcntl(outfd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
        outfile = os.fdopen(outfd)
        try:
            return fetch_output(infd, outfile)
        finally:
            # The child keeps running the rest of the script; don't leak the
            # vsys descriptors into it.
            outfile.close()
            os.close(infd)

    signal.alarm(BLOCKING_DELAY)
    # NOTE(review): on Python >= 3.5, os.waitpid is retried automatically
    # after EINTR (PEP 475), so this alarm would not interrupt the wait there.
    try:
        os.waitpid(pid, 0)
        timed_out = False
    except OSError:
        timed_out = True
    signal.alarm(0)

    if timed_out:
        # Kill the stuck child. Otherwise it could unblock later and run the
        # setup concurrently with the child of the next retry, and each
        # abandoned child would linger as a zombie.
        try:
            os.kill(pid, signal.SIGKILL)
            os.waitpid(pid, 0)
        except OSError:
            pass  # the child has already exited and been reaped
        return None

    # NOTE(review): the parent exits with 1 even when the child succeeded.
    # This is kept unchanged in case whatever launches the script relies on it.
    sys.exit(1)
69
def get_ip(ifname):
    """Return the IPv4 address assigned to network interface *ifname*.

    Uses the Linux ``SIOCGIFADDR`` ioctl, following the recipe at
    http://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib

    :param ifname: interface name such as ``"eth0"`` (str or bytes). Only the
        first 15 characters are used.
    :returns: the address as a dotted-quad string.
    :raises IOError/OSError: if the ioctl fails, e.g. the interface does not
        exist or has no IPv4 address.
    """
    # struct.pack('256s') needs bytes on Python 3; on Python 2 str is bytes.
    if not isinstance(ifname, bytes):
        ifname = ifname.encode("ascii")
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        ifreq = fcntl.ioctl(sock.fileno(),
                            0x8915,  # SIOCGIFADDR
                            struct.pack('256s', ifname[:15]))
    finally:
        # The socket is only a handle for the ioctl; always release it.
        sock.close()
    # In the returned struct ifreq, the IPv4 address occupies bytes 20-23.
    return socket.inet_ntoa(ifreq[20:24])
77
### main()

# The vsys script name (argv[1]) is required further down. Check for it
# before doing any work, so a missing argument fails immediately instead of
# with an IndexError after all the downloads.
if len(sys.argv) < 2:
    print("usage: %s <vsys script name>" % sys.argv[0])
    sys.exit(-1)

os.system("sudo mkdir -m 777 /usr/local/mpifix")

# (destination, URL) of every file fetched from HTML_SOURCE.
for dest, url in [
        ("/usr/bin/mpi-copy.py", MPI_COPY_URL),
        ("/usr/bin/generate-hostfile.py", GENERATE_HOSTFILE_URL),
        ("/tmp/sshd_config", SSHD_CONFIG_URL),
        ("/usr/local/mpifix/orted", ORTED_URL),
        ("/usr/local/mpifix/bind_public.so", BIND_PUBLIC_URL),
        ("/usr/local/mpifix/mpirun", MPIRUN_URL),
]:
    os.system("wget -q -O " + dest + " " + url)

for dest in ["/usr/bin/mpi-copy.py", "/usr/bin/generate-hostfile.py",
             "/usr/local/mpifix/orted", "/usr/local/mpifix/mpirun"]:
    os.system("chmod 755 " + dest)

# Abort if any of these downloads is missing or empty.
for fn in ["/usr/bin/mpi-copy.py", "/usr/bin/generate-hostfile.py", "/tmp/sshd_config"]:
    if not os.path.exists(fn):
        print("missing %s" % fn)
        sys.exit(-1)
    if os.path.getsize(fn) == 0:
        print("zero byte file %s" % fn)
        sys.exit(-1)

# The home directory is /home/<contents of /etc/slicename>.
with open("/etc/slicename", "r") as f:
    homedir = ("/home/" + f.read()).strip()
os.system("mkdir " + homedir + "/.ssh")

# vsys endpoints of the script named on the command line.
outpath = '/vsys/%s.out' % sys.argv[1]
inpath = '/vsys/%s.in' % sys.argv[1]
110
###
### If vsys is not set up, wait a bit and try again

try:
    # fetch_output_nofail needs a (no-op) SIGALRM handler, so that its
    # watchdog alarm interrupts the wait instead of terminating the process.
    signal.signal(signal.SIGALRM, alarm)

    counter = 1
    while True:
        output = fetch_output_nofail(inpath, outpath)
        if output is not None or counter == INSANE:
            break
        print("No access to vsys. Looping (%d)..." % counter)
        counter += 1

    if output is None:
        print("Could not connect to Vsys. Giving up.")
        sys.exit(1)

    print("successfully retrieved vsys output")

    try:
        eth0_ip = get_ip("eth0")
    except Exception:
        print("Failed to get local ip")
        eth0_ip = None

    # get the ip address of the host, we'll need it when setting up sshd
    hostname = socket.gethostname()
    hostipaddr = socket.gethostbyname(hostname)

    # See if we're behind a NAT (eth0 holds a 192.168.x.x address). If we are,
    # sshd listens on the public ip and MPI is routed through /usr/local/mpifix.
    use_public_ip = eth0_ip is not None and eth0_ip.startswith("192.168")

    print("hostname: %s public_ip: %s eth0_ip: %s"
          % (hostname, hostipaddr, eth0_ip))

    sshdir = homedir + "/.ssh"
    # Handle on ~/.ssh/id_rsa while the private-key lines are being copied;
    # None otherwise.
    keyfile = None
    for eachline in output.split("\n"):
        line = eachline.split()
        if not line:
            # Skip blank and whitespace-only lines (e.g. a lone "\r"), which
            # would otherwise make line[0] raise IndexError.
            continue
        if line[0] == "vsys_sshKey:":
            # First line of the private key: tokens 1-4 hold the key header
            # (presumably "-----BEGIN RSA PRIVATE KEY-----").
            # Create the file owner-only so the key is never readable by
            # others, even before the chmod further down.
            fd = os.open(sshdir + "/id_rsa",
                         os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            # O_CREAT's mode is ignored for an existing file.
            os.fchmod(fd, 0o600)
            keyfile = os.fdopen(fd, "w")
            keyfile.write(" ".join(line[1:5]) + "\n")
        elif keyfile is not None:
            if line[-1] == "KEY-----":
                # Closing line of the key (first four tokens).
                keyfile.write(" ".join(line[0:4]) + "\n")
                keyfile.close()
                keyfile = None
            else:
                keyfile.write(line[0] + "\n")
        elif line[0] == "vsys_portNumber:":
            portline = "Port " + line[1] + "\n"
            with open(sshdir + "/config", "w") as cfg:
                cfg.write(portline)
                cfg.write("StrictHostKeyChecking no\n")
                cfg.write("UserKnownHostsFile /dev/null\n")
                cfg.write("LogLevel quiet\n")
            with open("/tmp/portFile", "w") as portfile:
                portfile.write(portline)
            # Build the complete new sshd_config first. The final mv replaces
            # the old file, so it is not deleted before its replacement exists.
            os.system("sudo cat /tmp/sshd_config /tmp/portFile > /tmp/newsshd_config")
            if use_public_ip:
                with open("/tmp/newsshd_config", "a") as cfg:
                    cfg.write("ListenAddress " + hostipaddr + "\n")
            os.system("sudo mv /tmp/newsshd_config /etc/ssh/sshd_config")
        elif line[0] == "vsys_sshKey.pub:":
            pubpath = sshdir + "/id_rsa.pub"
            with open(pubpath, "w") as pub:
                pub.write(" ".join(line[1:]) + "\n")
            os.system("cp " + pubpath + " " + sshdir + "/authorized_keys")

    if keyfile is not None:
        # The output ended before the key's closing line; flush what arrived.
        keyfile.close()

    username = homedir[6:]  # strip the leading "/home/"
    os.system("sudo chown -R " + username + " " + sshdir)
    os.system("sudo chgrp -R slices " + sshdir)

    # /etc/profile isn't working for noninteractive ssh sessions, so set the mpi
    # paths in .bashrc
    if use_public_ip:
        # If we're behind a NAT, then we need to play tricks with rewriting the
        # IP addresses, so the /usr/local/mpifix wrappers go first on PATH.
        path_line = "export PATH=/usr/local/mpifix:$PATH:/usr/lib64/openmpi/bin\n"
    else:
        path_line = "export PATH=$PATH:/usr/lib64/openmpi/bin\n"
    with open(homedir + "/.bashrc", "w") as bashrc:
        bashrc.write(path_line
                     + "export LD_LIBRARY_PATH=/usr/lib64/openmpi/lib:$LD_LIBRARY_PATH\n")

    os.system("chmod og-rw " + sshdir + "/id_rsa")
    os.system("chmod og-rw " + sshdir + "/id_rsa.pub")
    os.system("chmod og-rw " + sshdir + "/config")
    os.system("rm -rf /tmp/portFile")
    os.system("sudo rm -rf /tmp/sshd_config")
except Exception:
    import traceback
    print("Unexpected error: %s" % sys.exc_info()[0])
    # The type alone rarely says what went wrong; show where it failed too.
    traceback.print_exc()
229
230