import repository from arizona
[raven.git] / apps / ravenpublish / slicerun.py
1 #! /usr/bin/python
2
3 import os
4 import select
5 import sys
6 import threading
7 import time
8 import xml.dom.minidom
9 from optparse import OptionParser
10 from ravenlib.ravenlog import RavenLog
11
# When True (and output capture is switched on), run_node() gives every node
# its own pipe and CaptureStreamThread, and captured lines are filed under the
# child's pid in StorkInstaller.output. When False, run_nodes() creates one
# pipe shared by all nodes and captured lines go to StorkInstaller.allData.
CAP_PER_THREAD = True
13
def interruptable_waitpid(pid=0, options=0):
    """Call os.waitpid(), retrying whenever a signal interrupts the call.

    pid and options are handed to os.waitpid() unchanged.

    Returns the (pid, status) tuple produced by os.waitpid().
    Raises OSError for every failure other than EINTR, for example ECHILD
    when there is no matching child left to wait for.
    """
    import errno

    while True:
        try:
            return os.waitpid(pid, options)
        except OSError as e:
            if e.errno != errno.EINTR:
                raise
            # EINTR: interrupted before any child changed state; wait again
24
def UTOA(x):
    """ un-unicode a string: hand back x encoded as UTF-8 """
    encoded = x.encode("utf8")
    return encoded
28
29 class StorkInstaller(RavenLog):
30     def __init__(self, slicename = None, style="protogeni"):
31         # RavenLog object is mixed in to give us a configurable Print statement
32         RavenLog.__init__(self)
33
34         self.nodelist = []
35         self.hide_output = False
36         self.capture_output = False
37         self.slicename = slicename
38         self.sslkeys = []
39
40         self.set_slicename = True
41         self.set_sslkeys = True
42         self.install_stork = True
43         self.force_install = False
44         self.force_slicename = False
45         self.force_sslkeys = False
46
47         self.style = style
48         self.default_port = 22
49         self.default_username = None
50         self.privateKeyName = None
51
52         self.outputLock = threading.Lock()
53         self.output = {}
54         self.allData = []
55
56         # max of 64 ssh sessions open at any one time
57         self.limit=128
58
59         if (slicename) and (self.style == "sfa"):
60             parts = slicename.split(".")
61             if len(parts)>=2:
62                 self.default_username = parts[-2] + "_" + parts[-1]
63
64         self.children = {}
65
66         self.create_pgroup()
67
68     def create_pgroup(self):
69         # Create a process group, so that we can put all of the ssh processes
70         # under one group.
71
72         parent_pid = os.getpid()
73         pid = os.fork()
74         if pid == 0:
75             # child
76             while True:
77                 # the parent pid will change to 1 when the parent goes away
78                 if (os.getppid() != parent_pid):
79                     sys.exit(0)
80
81                 # do nothing
82                 time.sleep(1)
83         else:
84             self.pgid = pid
85             os.setpgid(self.pgid, self.pgid)
86
87     def load_rspec(self, rspec):
88         self.nodelist = []
89         doc = xml.dom.minidom.parseString(rspec)
90
91         if (self.style == "protogeni"):
92             self.load_rspec_protogeni(doc.documentElement)
93         elif (self.style == "sfa"):
94             self.load_rspec_sfa(doc.documentElement)
95
96     def load_rspec_protogeni(self, root):
97         self.nodelist.extend( self.load_network(root) )
98
99     def load_rspec_sfa(self, root):
100         for network in filter (lambda x: x.nodeName == "network", root.childNodes):
101             self.nodelist.extend( self.load_network(network) )
102
103     def load_network(self, network):
104         nodelist = []
105
106         for node in filter (lambda x: x.nodeName == "node", network.childNodes):
107             node_dict = {}
108
109             if node.hasAttribute("hostname"):
110                 # protogeni
111                 node_dict["hostname"] = UTOA(node.getAttribute("hostname"))
112             else:
113                 # sfa
114                 for hostnameNode in filter (lambda x: x.nodeName == "hostname", node.childNodes):
115                     if (len(hostnameNode.childNodes)>=1) and (hostnameNode.childNodes[0].nodeName=="#text"):
116                         node_dict["hostname"] = hostnameNode.childNodes[0].nodeValue
117
118             if not "hostname" in node_dict:
119                 continue
120
121             if node.hasAttribute("sshdport"):
122                 node_dict["port"] = int(node.getAttribute("sshdport"))
123
124             for services in filter (lambda x: x.nodeName == "services", node.childNodes):
125                 for login in filter (lambda x: x.nodeName == "login", services.childNodes):
126                     if login.hasAttribute("port"):
127                         # this might override the port we got above; so what
128                         node_dict["port"] = int(login.getAttribute("port"))
129
130                     if login.hasAttribute("username"):
131                         node_dict["username"] = UTOA(login.getAttribute("username"))
132
133             # SFA slices have a <Sliver> tag to indicate that the slice is
134             # on that node. No sliver tag = do not include.
135             hasSliver = False
136             for sliver in filter (lambda x: x.nodeName == "sliver", node.childNodes):
137                 hasSliver = True
138
139             if (self.style=="protogeni") or (hasSliver):
140                 nodelist.append(node_dict)
141
142         return nodelist
143
144     def build_check_script(self, hostname):
145         script = []
146
147         script.append("HOSTNAME=`hostname`")
148         script.append("if [ -e /etc/slicename ]; then")
149         script.append("SLICENAME=`cat /etc/slicename`")
150         script.append("else")
151         script.append("SLICENAME=no_slice_name")
152         script.append("fi")
153         script.append("STORKRPM=`rpm -q stork-client`")
154         script.append("echo $HOSTNAME, $SLICENAME, $STORKRPM")
155
156         return "\n".join(script)
157
158     def build_processes_script(self, hostname):
159         script = []
160
161         script.append("HOSTNAME=`hostname`")
162         script.append("if [ -e /etc/slicename ]; then")
163         script.append("SLICENAME=`cat /etc/slicename`")
164         script.append("else")
165         script.append("SLICENAME=no_slice_name")
166         script.append("fi")
167         script.append("PROCESSES=`ps aex | wc -l`")
168         script.append("echo $HOSTNAME, $PROCESSES")
169
170         return "\n".join(script)
171
172     def build_install_script(self):
173         script = []
174
175         script.append("echo \* start: `/bin/hostname`")
176
177         if self.set_slicename and self.slicename:
178             if not self.force_slicename:
179                 script.append("if [ ! -e /etc/slicename ]; then")
180
181             script.append("echo \* setting slicename")
182             script.append("echo " + self.slicename + " > /etc/slicename")
183
184             if not self.force_slicename:
185                 script.append("fi")
186
187         if self.set_sslkeys and self.sslkeys:
188             if not self.force_sslkeys:
189                 script.append("if [ ! -e /etc/storksslkeys ]; then")
190
191                 script.append("echo \* setting sslkeys")
192                 script.append("echo \# generated by \"raven slice\" > /etc/storksslkeys")
193
194                 for key in self.sslkeys:
195                     script.append("echo " + key + " >> /etc/storksslkeys")
196
197             if not self.force_sslkeys:
198                 script.append("fi")
199
200         if self.install_stork:
201             if not self.force_install:
202                 # XXX using rpm -q and checking status with $? wasn't working
203                 script.append("rpm -q stork-client")
204                 script.append("if [ $? != 0 ]; then")
205                 # script.append("if [ ! -e /usr/local/stork/bin/stork.py ]; then")
206
207             script.append("echo \* downloading stork initscript")
208             script.append("wget http://stork-repository.cs.arizona.edu:/stork-install/initscript -O /tmp/stork_initscript")
209             script.append("echo running stork initscript")
210             script.append("bash /tmp/stork_initscript")
211
212             if not self.force_install:
213                 script.append("fi")
214
215         script.append("echo \* complete: `/bin/hostname`")
216
217         return "\n".join(script)
218
219     def run_node(self, node, script):
220         if (not script) or (script == "XXX_install"):
221             script = self.build_install_script()
222
223         if (script == "XXX_check"):
224             script = self.build_check_script(hostname = node["hostname"])
225
226         if (script == "XXX_processes"):
227             script = self.build_processes_script(hostname = node["hostname"])
228
229         if CAP_PER_THREAD and self.capture_output:
230             (self.stdout_r, self.stdout_w) = os.pipe()
231
232         pid = os.fork()
233         if (pid != 0):
234             # parent
235             if CAP_PER_THREAD and self.capture_output:
236                 os.close(self.stdout_w)
237                 capturer = CaptureStreamThread(self.stdout_r, pid, self)
238
239             self.children[pid] = node
240             return
241
242         os.setpgid(0, self.pgid)
243
244         if self.capture_output:
245             os.close(self.stdout_r)
246             os.dup2(self.stdout_w, 1)
247             #os.dup2(stdout_w, 2)
248         elif self.hide_output:
249             devnull = os.open( "/dev/null", os.O_WRONLY )
250             os.dup2( devnull, 1 )\r
251             os.dup2( devnull, 2 )\r
252             os.close( devnull )
253
254         args = ["ssh", "-n", "-T", "-o", "BatchMode yes", "-o", "StrictHostKeyChecking no"]
255
256         port = node.get("port", self.default_port)
257         args.extend( ["-p", str(port)] )
258
259         username = node.get("username", self.default_username)
260         args.extend( ["-l", username] )
261
262         if self.privateKeyName:
263             args.extent( ["-i", self.privateKeyName] )
264
265         # ` and $ would be expanded before being fed into bash
266         script = script.replace("`", "\\`")
267         script = script.replace("$", "\\$")
268
269         command = "sudo bash <<EOF\n" + script + "\n" + "EOF"
270
271         args.append(node["hostname"])
272         args.append(command)
273
274         os.execv("/usr/bin/ssh", args)
275         os._exit( 1 )
276
277     def run_nodes(self, nodelist=None, script=None, show_completions=False):
278         if (not CAP_PER_THREAD) and self.capture_output:
279             (self.stdout_r, self.stdout_w) = os.pipe()
280             capturer = CaptureStreamThread(self.stdout_r, 0, self)
281
282         if not nodelist:
283             nodelist = self.nodelist
284
285         for node in nodelist:
286             self.run_node(node, script)
287
288             # only do so many SSH sessions at a time...
289             if (self.limit) and (len(self.children) > self.limit):
290                 self.wait(count=1, show_completions=show_completions, verbose=False)
291
292     def relay_output(self, pid):
293         self.outputLock.acquire()
294
295         # self.alldata is for CAP_PER_THREAD==False
296         if self.allData:
297             for line in "".join(self.allData).strip("\n").split("\n"):
298                 self.Print(line)
299             self.allData=[]
300
301         # self.output is for CAP_PER_THREAD==True
302         if pid in self.output:
303             for line in "".join(self.output[pid]).strip("\n").split("\n"):
304                 self.Print(line)
305
306         self.outputLock.release()
307
308     def wait(self, count=None, show_completions=True, verbose=True):
309         #children = self.children.copy()
310         if verbose:
311             self.Print("waiting for", len(self.children), "children to complete")
312
313         while True:
314            if len(self.children) == 0:
315                return
316
317            (pid, exitstatus) = interruptable_waitpid(-self.pgid)
318
319            if pid in self.children:
320                self.relay_output(pid)
321                if show_completions:
322                    self.Print("child completed:", self.children[pid]["hostname"])
323                del self.children[pid]
324                if (count != None):
325                    count = count - 1
326                    if (count<0):
327                        return
328
329     def run_and_wait(self, nodelist=None, script=None, show_completions=True):
330         self.run_nodes(nodelist, script, show_completions)
331         self.wait(show_completions = show_completions)
332
333     def dump(self):
334         for node in self.nodelist:
335             line =  node["hostname"]
336             if "port" in node:
337                 line = " " + str(node["port"])
338             if "username" in node:
339                 line = " " + str(node["username"])
340             self.Print(line)
341
342     def handle_output_line(self, pid, line):
343         self.outputLock.acquire()
344         if CAP_PER_THREAD:
345             data = self.output.get(pid, [])
346             data.append(line)
347             self.output[pid] = data
348         else:
349             self.allData.append(line)
350         self.outputLock.release()
351
class CaptureStreamThread(threading.Thread):
    """Daemon thread that reads lines from a file descriptor and passes each
    one to parent.handle_output_line(pid, line).

    The thread starts itself from the constructor.
    """

    def __init__(self, stream, pid, parent):
        """
        stream -- OS-level file descriptor to read from
        pid    -- value passed back with every line to identify the source
        parent -- object providing handle_output_line(pid, line)
        """
        threading.Thread.__init__(self)
        self.parent = parent
        self.stream = os.fdopen(stream, "r")
        self.pid = pid
        self.daemon = True
        self.start()

    def run(self):
        """ capture the output from the stream. For each line of output, call
            parent.handle_output_line(). Parent better deal with mutual
            exclusion.

            Returns at end of stream. An IOError other than EINTR is reported
            on stderr and re-raised. The stream is closed on every exit path.
        """
        import errno

        try:
            while True:
                try:
                    line = self.stream.readline()
                    if not line:
                        return
                    self.parent.handle_output_line(self.pid, line)
                except IOError as e:
                    if e.errno == errno.EINTR:
                        # interrupted by a signal; just read again
                        continue
                    sys.stderr.write("%s\n" % (e,))
                    raise
        finally:
            # release the read end of the pipe as soon as we are done with it
            self.stream.close()
379
def add_parser_options(parser):
    """Register the slicerun command line options on an optparse parser.

    parser -- the OptionParser to extend; it is modified in place.
    """
    # note: also used from raven.py
    parser.add_option("", "--nohide", dest="nohide",
         help="do not hide output of commands executed on nodes", action="store_true", default=False)
    parser.add_option("", "--forceinstall", dest="forceinstall",
         help="install stork even if it's already installed", action="store_true", default=False)
    # this flag controls the slicename, so the help text must say so
    parser.add_option("", "--forceslicename", dest="forceslicename",
         help="set the slicename even if it's already set", action="store_true", default=False)
    parser.add_option("", "--slicename", dest="slicename",
         help="set slicename on nodes", default=None)
    parser.add_option("", "--sslkey", dest="sslkey",
         help="set sslkey on nodes", default=None)
    parser.add_option("", "--style", dest="style",
         help="set style (protogeni|sfa)", default="protogeni")
394
def create_parser():
    """Build the command line parser for slicerun.py.

    Returns an OptionParser carrying the options from add_parser_options().
    Interspersed arguments are disabled, so option parsing stops at the
    first positional argument and everything after it is left untouched.
    """
    # usage and description now match the commands that main() dispatches on
    parser = OptionParser(usage="slicerun.py [options] [dump|exec|install|check] rspec_fn",
         description="Run a command or the stork installer on the nodes listed in an rspec")

    add_parser_options(parser)

    parser.disable_interspersed_args()

    return parser
405
def main(program, options=None, args=None):
    """Command line entry point.

    program -- program name used in the syntax messages
    options -- already parsed options; when missing, the command line is
               parsed with create_parser() and args is replaced as well
    args    -- positional arguments: command, rspec file name and, for
               "exec", the command to run on the nodes

    Exits with status -1 on a syntax error, an unknown command or a missing
    rspec file. Arguments are validated before the StorkInstaller is built,
    since its constructor forks a helper process.
    """
    if not options:
        parser = create_parser()

        (options, args) = parser.parse_args()

    if (not args) or (len(args) < 2):
        print("syntax: %s [options] [dump|exec|install|check] rspec_fn" % program)
        sys.exit(-1)

    cmd = args[0]
    fn = args[1]

    if cmd not in ("dump", "install", "exec", "check"):
        # an unknown command used to be ignored without any message
        print("syntax: %s [options] [dump|exec|install|check] rspec_fn" % program)
        sys.exit(-1)

    if (cmd == "exec") and (len(args) <= 2):
        print("syntax: %s [options] exec rspec_fn cmd" % program)
        sys.exit(-1)

    if not os.path.exists(fn):
        # stop here; carrying on would only fail inside open()
        print("file %s does not exist" % fn)
        sys.exit(-1)

    with open(fn, "r") as f:
        rspec = f.read()

    si = StorkInstaller( slicename = options.slicename, style = options.style )

    si.load_rspec(rspec)

    si.hide_output = not options.nohide
    si.force_slicename = options.forceslicename
    si.force_install = options.forceinstall

    if options.sslkey:
        si.sslkeys = [options.sslkey]

    if (cmd == "dump"):
        si.dump()
    elif (cmd == "install"):
        si.run_nodes()
        si.wait()
    elif (cmd == "exec"):
        si.run_nodes(script = " ".join(args[2:]))
        si.wait()
    elif (cmd == "check"):
        si.hide_output = False
        si.run_nodes(script="XXX_check")
        si.wait(show_completions=False)
450
451
if __name__ == "__main__":
    # executed as a script: hand over to the command line entry point
    main("slicerun.py")