fix wrong IP address issue
[mpi-image.git] / scripts / init.py
1 #!/usr/bin/python
2
3 import sys
4 import os
5 import signal
6 import socket
7 import fcntl
8 import pdb
9 import getpass
10
# Upper bound on the number of vsys connection attempts before giving up.
INSANE = 500
# Seconds the parent waits on each forked vsys attempt before retrying.
BLOCKING_DELAY = 30

# Base location (given without a URL scheme) of the files fetched at start-up.
HTML_SOURCE = "www.vicci.org/files/mpi/"
MPI_COPY_URL = "%smpi-copy.py" % HTML_SOURCE
GENERATE_HOSTFILE_URL = "%sgenerate-hostfile.py" % HTML_SOURCE
SSHD_CONFIG_URL = "%ssshd_config" % HTML_SOURCE
18
19
def alarm(signum, junk):
    """SIGALRM handler that deliberately does nothing.

    Installing it replaces the default SIGALRM action (terminate the
    process), so an expiring alarm only interrupts the blocking call in
    progress instead of killing the script.

    NOTE(review): on Python 3.5+ interrupted system calls such as
    os.waitpid() are retried automatically when the handler returns
    normally (PEP 475), so there the alarm would not break the wait.
    """
    pass
23
def fetch_output(infd, outfile):
    """Send a request on a vsys input pipe and collect the complete reply.

    Writes the single character '1' to ``infd`` (presumably the trigger the
    vsys script waits for -- confirm), then reads ``outfile`` until read()
    returns an empty chunk.

    infd    -- OS-level file descriptor open for writing.
    outfile -- file object open for reading.
    Returns the concatenated text read from ``outfile``.
    Raises IOError/OSError (after printing a diagnostic) if either side fails.
    """
    chunks = []
    try:
        # '1'.encode() is a str on Python 2 and bytes on Python 3 -- whichever
        # os.write() expects on the running interpreter.
        os.write(infd, '1'.encode('ascii'))
        while True:
            batch = outfile.read()
            if not batch:
                break
            chunks.append(batch)
    except (IOError, OSError):
        # os.write() reports failures as OSError, which is distinct from
        # IOError on Python 2, so both are caught; a bare raise keeps the
        # original traceback intact.
        print("IO Error Accessing Vsys")
        raise
    return ''.join(chunks)
40
def fetch_output_nofail(inpath, outpath):
    """Make one forked, time-limited attempt to query vsys.

    Forks.  The child opens the vsys pipes, performs the exchange with
    fetch_output() and *returns* the reply, so the child is the process that
    carries on with the rest of the script.

    The parent arms a BLOCKING_DELAY-second SIGALRM and waits for the child:
      * if the child exits, the parent exits with status 1;
      * if the wait fails (typically because SIGALRM interrupted it while the
        child was still blocked), the child is killed and reaped and None is
        returned so the caller can retry.

    inpath  -- path of the vsys ".in" pipe (opened for writing by the child).
    outpath -- path of the vsys ".out" pipe (opened for reading by the child).
    Returns the reply text in the child, None in the parent on failure.
    """
    pid = os.fork()

    if pid == 0:
        # Open the read side non-blocking so open() itself returns at once
        # (presumably these are FIFOs -- confirm), then open the write side,
        # which may block until vsys is listening.
        outfd = os.open(outpath, os.O_RDONLY | os.O_NONBLOCK)
        infd = os.open(inpath, os.O_WRONLY)

        # Switch the read side back to blocking so reads in fetch_output()
        # wait for data.
        flags = fcntl.fcntl(outfd, fcntl.F_GETFL)
        fcntl.fcntl(outfd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
        outfile = os.fdopen(outfd)

        return fetch_output(infd, outfile)

    signal.alarm(BLOCKING_DELAY)
    try:
        os.waitpid(pid, 0)
    except OSError:
        signal.alarm(0)
        # Kill and reap the abandoned child: left alone it would linger and,
        # if vsys later answered, run the setup concurrently with the child
        # forked by the next retry.
        try:
            os.kill(pid, signal.SIGKILL)
            os.waitpid(pid, 0)
        except OSError:
            pass  # child already gone or already reaped
        return None

    signal.alarm(0)
    sys.exit(1)
66
67
68 ### main()
69
# Fetch the helper scripts and the sshd_config template.
for _url, _dest in ((MPI_COPY_URL, "/usr/bin/mpi-copy.py"),
                    (GENERATE_HOSTFILE_URL, "/usr/bin/generate-hostfile.py"),
                    (SSHD_CONFIG_URL, "/tmp/sshd_config")):
    os.system("wget -q -O %s %s" % (_dest, _url))

# Make the two helper scripts executable.
for _script in ("/usr/bin/mpi-copy.py", "/usr/bin/generate-hostfile.py"):
    os.system("chmod 755 " + _script)

# wget's exit status is ignored above, so verify the results directly: each
# file must exist and be non-empty (a failed "wget -O" can presumably leave
# an empty file behind).
for _path in ["/usr/bin/mpi-copy.py", "/usr/bin/generate-hostfile.py",
              "/tmp/sshd_config"]:
    if not os.path.exists(_path):
        print("missing %s" % _path)
        sys.exit(-1)
    if os.path.getsize(_path) == 0:
        print("zero byte file %s" % _path)
        sys.exit(-1)
84
# The slice's home directory is /home/<contents of /etc/slicename>.
slicename_file = open("/etc/slicename", "r")
try:
    homedir = ("/home/" + slicename_file.read()).strip()
finally:
    slicename_file.close()

# os.system ignores mkdir's exit status, so an existing ~/.ssh is harmless.
os.system("mkdir " + homedir + "/.ssh")

# The vsys script name is the first command-line argument.
if len(sys.argv) < 2:
    print("usage: %s <vsys-script-name>" % sys.argv[0])
    sys.exit(-1)
outpath = '/vsys/%s.out' % sys.argv[1]
inpath = '/vsys/%s.in' % sys.argv[1]
93
###
### Reach vsys. Each attempt forks (see fetch_output_nofail); if vsys is not
### set up yet the attempt times out and we try again, up to INSANE times.

try:
    # Install the no-op handler so SIGALRM does not terminate the process.
    signal.signal(signal.SIGALRM, alarm)
    output = None
    for counter in range(1, INSANE + 1):
        output = fetch_output_nofail(inpath, outpath)
        if output is not None:
            break
        if counter < INSANE:
            print("No access to vsys. Looping (%d)..." % counter)

    if output is None:
        print("Could not connect to Vsys. Giving up.")
        sys.exit(1)

    # Only the forked child reaches this point with a reply; the waiting
    # parent exits inside fetch_output_nofail. Look up the address sshd
    # should listen on.
    hostname = socket.gethostname()
    hostipaddr = socket.gethostbyname(hostname)
    print("found hostname %s and ip %s" % (hostname, hostipaddr))

    # The reply is line oriented:
    #   "vsys_sshKey: <header>"  starts a private key, copied to ~/.ssh/id_rsa
    #                            one token per line until a line ending in
    #                            "KEY-----";
    #   "vsys_portNumber: <n>"   is the ssh port;
    #   "vsys_sshKey.pub: ..."   is the public key.
    keyfile = None  # open ~/.ssh/id_rsa while a private key is being copied
    for eachline in output.split("\n"):
        line = eachline.split()
        if not line:
            # Skip blank *and* whitespace-only lines; the latter used to hit
            # line[0] on an empty list and abort the whole setup.
            continue
        if line[0] == "vsys_sshKey:":
            if keyfile is not None:
                keyfile.close()
            keyfile = open(homedir + "/.ssh/id_rsa", "w+")
            # Presumably the PEM header, e.g. "-----BEGIN RSA PRIVATE KEY-----".
            keyfile.write(" ".join(line[1:5]) + "\n")
        elif keyfile is not None:
            if line[-1] == "KEY-----":
                # Footer line terminates the key.
                keyfile.write(" ".join(line[:4]) + "\n")
                keyfile.close()
                keyfile = None
            else:
                keyfile.write(line[0] + "\n")
        elif line[0] == "vsys_portNumber:":
            port_line = "Port " + line[1] + "\n"
            config = open(homedir + "/.ssh/config", "w+")
            config.write(port_line)
            config.write("StrictHostKeyChecking no\n")
            config.write("UserKnownHostsFile /dev/null\n")
            config.write("LogLevel quiet\n")
            config.close()
            port_file = open("/tmp/portFile", "w+")
            port_file.write(port_line)
            port_file.close()
            # System sshd_config = downloaded template + port + listen address.
            os.system("sudo rm /etc/ssh/sshd_config")
            os.system("sudo cat /tmp/sshd_config /tmp/portFile > /tmp/newsshd_config")
            os.system('sudo echo "ListenAddress ' + hostipaddr + '" >> /tmp/newsshd_config')
            os.system("sudo mv /tmp/newsshd_config /etc/ssh/sshd_config")
        elif line[0] == "vsys_sshKey.pub:":
            pathname = homedir + "/.ssh/id_rsa.pub"
            pubfile = open(pathname, "w+")
            pubfile.write(" ".join(line[1:]) + "\n")
            pubfile.close()
            # The slice's own public key also becomes its authorized_keys.
            os.system("cp " + pathname + " " + homedir + "/.ssh/authorized_keys")

    if keyfile is not None:
        # The footer never arrived; still close (and flush) what was written.
        keyfile.close()

    username = homedir[6:]  # homedir is "/home/" + slice name
    os.system("sudo chown -R " + username + " " + homedir + "/.ssh")
    os.system("sudo chgrp -R slices " + homedir + "/.ssh")
    for name in ("id_rsa", "id_rsa.pub", "config"):
        os.system("chmod og-rw " + homedir + "/.ssh/" + name)
    os.system("rm -rf /tmp/portFile")
    os.system("sudo rm -rf /tmp/sshd_config")
except Exception:
    # Report the message as well as the type, and signal failure.
    print("Unexpected error: %s %s" % sys.exc_info()[:2])
    sys.exit(1)
189
190