use inheritance to model the various kinds of RegRecords
Thierry Parmentelat [Thu, 19 Jan 2012 17:34:49 +0000 (18:34 +0100)]
sfa/importer/sfa-import-plc.py
sfa/importer/sfaImport.py
sfa/managers/registry_manager.py
sfa/storage/persistentobjs.py

index 9076237..1c38bc5 100755 (executable)
@@ -28,7 +28,7 @@ from sfa.trust.certificate import convert_public_key, Keypair
 from sfa.plc.plshell import PlShell    
 
 from sfa.storage.alchemy import dbsession
-from sfa.storage.persistentobjs import RegRecord
+from sfa.storage.persistentobjs import RegRecord, RegAuthority, RegUser, RegSlice, RegNode
 
 from sfa.importer.sfaImport import sfaImport, _cleanup_string
 
@@ -138,27 +138,26 @@ def main():
         i2site = {'name': 'Internet2', 'abbreviated_name': 'I2',
                     'login_base': 'internet2', 'site_id': -1}
         site_hrn = _get_site_hrn(interface_hrn, i2site)
-        logger.info("Importing site: %s" % site_hrn)
         # import if hrn is not in list of existing hrns or if the hrn exists
         # but its not a site record
         if site_hrn not in existing_hrns or \
            (site_hrn, 'authority') not in existing_records:
-            logger.info("Import: site %s " % site_hrn)
             urn = hrn_to_urn(site_hrn, 'authority')
             if not sfaImporter.AuthHierarchy.auth_exists(urn):
                 sfaImporter.AuthHierarchy.create_auth(urn)
             auth_info = sfaImporter.AuthHierarchy.get_auth_info(urn)
-            auth_record = RegRecord("authority", hrn=site_hrn, gid=auth_info.get_gid_object(),
-                                    pointer=site['site_id'], 
-                                    authority=get_authority(site_hrn))
-            logger.info("Import: Importing auth %s"%auth_record)
+            auth_record = RegAuthority()
+            auth_record.hrn=site_hrn
+            auth_record.gid=auth_info.get_gid_object()
+            auth_record.pointer=site['site_id']
+            auth_record.authority=get_authority(site_hrn)
             dbsession.add(auth_record)
             dbsession.commit()
+            logger.info("Import: Imported authority (vini site) %s"%auth_record)
 
     # start importing 
     for site in sites:
         site_hrn = _get_site_hrn(interface_hrn, site)
-        logger.info("Importing site: %s" % site_hrn)
 
         # import if hrn is not in list of existing hrns or if the hrn exists
         # but its not a site record
@@ -169,12 +168,14 @@ def main():
                 if not sfaImporter.AuthHierarchy.auth_exists(urn):
                     sfaImporter.AuthHierarchy.create_auth(urn)
                 auth_info = sfaImporter.AuthHierarchy.get_auth_info(urn)
-                auth_record = RegRecord("authority", hrn=site_hrn, gid=auth_info.get_gid_object(),
-                                        pointer=site['site_id'], 
-                                        authority=get_authority(site_hrn))
-                logger.info("Import: importing site: %s" % auth_record)  
+                auth_record = RegAuthority()
+                auth_record.hrn=site_hrn
+                auth_record.gid=auth_info.get_gid_object()
+                auth_record.pointer=site['site_id']
+                auth_record.authority=get_authority(site_hrn)
                 dbsession.add(auth_record)
                 dbsession.commit()
+                logger.info("Import: imported authority (site) : %s" % auth_record)  
             except:
                 # if the site import fails then there is no point in trying to import the
                 # site's child records (node, slices, persons), so skip them.
@@ -197,12 +198,14 @@ def main():
                     pkey = Keypair(create=True)
                     urn = hrn_to_urn(hrn, 'node')
                     node_gid = sfaImporter.AuthHierarchy.create_gid(urn, create_uuid(), pkey)
-                    node_record = RegRecord("node", hrn=hrn, gid=node_gid,
-                                            pointer=node['node_id'], 
-                                            authority=get_authority(hrn))    
-                    logger.info("Import: importing node: %s" % node_record)  
+                    node_record = RegNode ()
+                    node_record.hrn=hrn
+                    node_record.gid=node_gid
+                    node_record.pointer =node['node_id']
+                    node_record.authority=get_authority(hrn)
                     dbsession.add(node_record)
                     dbsession.commit()
+                    logger.info("Import: imported node: %s" % node_record)  
                 except:
                     logger.log_exc("Import: failed to import node") 
                     
@@ -221,12 +224,14 @@ def main():
                     pkey = Keypair(create=True)
                     urn = hrn_to_urn(hrn, 'slice')
                     slice_gid = sfaImporter.AuthHierarchy.create_gid(urn, create_uuid(), pkey)
-                    slice_record = RegRecord("slice", hrn=hrn, gid=slice_gid, 
-                                             pointer=slice['slice_id'],
-                                             authority=get_authority(hrn))
-                    logger.info("Import: importing slice: %s" % slice_record)  
+                    slice_record = RegSlice ()
+                    slice_record.hrn=hrn
+                    slice_record.gid=slice_gid
+                    slice_record.pointer=slice['slice_id']
+                    slice_record.authority=get_authority(hrn)
                     dbsession.add(slice_record)
                     dbsession.commit()
+                    logger.info("Import: imported slice: %s" % slice_record)  
                 except:
                     logger.log_exc("Import: failed to  import slice")
 
@@ -268,12 +273,14 @@ def main():
                         pkey = Keypair(create=True) 
                     urn = hrn_to_urn(hrn, 'user')
                     person_gid = sfaImporter.AuthHierarchy.create_gid(urn, create_uuid(), pkey)
-                    person_record = RegRecord("user", hrn=hrn, gid=person_gid,
-                                              pointer=person['person_id'], 
-                                              authority=get_authority(hrn))
-                    logger.info("Import: importing person: %s" % person_record)
+                    person_record = RegUser ()
+                    person_record.hrn=hrn
+                    person_record.gid=person_gid
+                    person_record.pointer=person['person_id']
+                    person_record.authority=get_authority(hrn)
                     dbsession.add (person_record)
                     dbsession.commit()
+                    logger.info("Import: imported person: %s" % person_record)
                 except:
                     logger.log_exc("Import: failed to import person.") 
     
index 3f8284d..dbcb66d 100644 (file)
@@ -16,7 +16,8 @@ from sfa.trust.certificate import convert_public_key, Keypair
 from sfa.trust.trustedroots import TrustedRoots
 from sfa.trust.hierarchy import Hierarchy
 from sfa.trust.gid import create_uuid
-from sfa.storage.persistentobjs import RegRecord
+from sfa.storage.persistentobjs import RegRecord, RegAuthority, RegUser
+from sfa.storage.persistentobjs import RegTmpAuthSa, RegTmpAuthAm, RegTmpAuthSm
 from sfa.storage.alchemy import dbsession
 
 def _un_unicode(str):
@@ -65,7 +66,8 @@ class sfaImport:
 
         # create interface records
         self.logger.info("Import: creating interface records")
-        self.create_interface_records()
+# xxx authority+ turning off the creation of authority+*
+#        self.create_interface_records()
 
         # add local root authority's cert  to trusted list
         self.logger.info("Import: adding " + interface_hrn + " to trusted list")
@@ -87,12 +89,14 @@ class sfaImport:
         self.AuthHierarchy.create_top_level_auth(hrn)    
         # create the db record if it doesnt already exist    
         auth_info = self.AuthHierarchy.get_auth_info(hrn)
-        auth_record = RegRecord("authority", hrn=hrn, gid=auth_info.get_gid_object(), 
-                                authority=get_authority(hrn))
+        auth_record = RegAuthority()
+        auth_record.hrn=hrn
+        auth_record.gid=auth_info.get_gid_object()
+        auth_record.authority=get_authority(hrn)
         auth_record.just_created()
-        self.logger.info("Import: importing auth %s " % auth_record)
         dbsession.add (auth_record)
         dbsession.commit()
+        self.logger.info("Import: imported authority (parent) %s " % auth_record)
 
     def create_sm_client_record(self):
         """
@@ -105,13 +109,16 @@ class sfaImport:
             self.AuthHierarchy.create_auth(urn)
 
         auth_info = self.AuthHierarchy.get_auth_info(hrn)
-        user_record = RegRecord("user", hrn=hrn, gid=auth_info.get_gid_object(), \
-                                   authority=get_authority(hrn))
+        user_record = RegUser()
+        user_record.hrn=hrn
+        user_record.gid=auth_info.get_gid_object()
+        user_record.authority=get_authority(hrn)
         user_record.just_created()
-        self.logger.info("Import: importing user %s " % user_record)
         dbsession.add (user_record)
         dbsession.commit()
+        self.logger.info("Import: importing user (slicemanager) %s " % user_record)
 
+# xxx authority+ - this is currently turned off 
     def create_interface_records(self):
         """
         Create a record for each SFA interface
@@ -119,21 +126,29 @@ class sfaImport:
         # just create certs for all sfa interfaces even if they
         # arent enabled
         hrn = self.config.SFA_INTERFACE_HRN
-        interfaces = ['authority+sa', 'authority+am', 'authority+sm']
+        reg_classes_info = [ (RegTmpAuthSa, 'authority+sa'),
+                          (RegTmpAuthAm, 'authority+am'),
+                          (RegTmpAuthSm, 'authority+sm'), ]
+        # interfaces = ['authority+sa', 'authority+am', 'authority+sm']
         auth_info = self.AuthHierarchy.get_auth_info(hrn)
         pkey = auth_info.get_pkey_object()
-        for interface in interfaces:
+        for (reg_class, interface) in reg_classes_info:
             urn = hrn_to_urn(hrn, interface)
             gid = self.AuthHierarchy.create_gid(urn, create_uuid(), pkey)
-            interface_record = RegRecord(interface, hrn=hrn, gid = gid, 
-                                         authority=get_authority(hrn))
+            # xxx this should probably use a RegAuthority, or a to-be-defined RegPeer object
+            # but for now we have to preserve the authority+<> stuff
+            interface_record = reg_class()
+            #interface_record = RegAuthority()
+            interface_record.hrn=hrn
+            interface_record.gid= gid
+            interface_record.authority=get_authority(hrn)
             interface_record.just_created()
-            self.logger.info("Import: importing %s " % interface_record)
             dbsession.add (interface_record)
             dbsession.commit()
+            self.logger.info("Import: imported authority (%s) %s " % (interface,interface_record))
              
     def delete_record(self, hrn, type):
         # delete the record
         for rec in dbsession.query(RegRecord).filter_by(type=type,hrn=hrn):
-           del rec
+           dbsession.delete(rec)
         dbsession.commit()
index e40772d..7c92904 100644 (file)
@@ -18,7 +18,7 @@ from sfa.trust.credential import Credential
 from sfa.trust.certificate import Certificate, Keypair, convert_public_key
 from sfa.trust.gid import create_uuid
 
-from sfa.storage.persistentobjs import RegRecord
+from sfa.storage.persistentobjs import make_record,RegRecord
 from sfa.storage.alchemy import dbsession
 
 class RegistryManager:
@@ -52,7 +52,7 @@ class RegistryManager:
         # get record info
         record=dbsession.query(RegRecord).filter_by(type=type,hrn=hrn).first()
         if not record:
-            raise RecordNotFound(hrn)
+            raise RecordNotFound("hrn=%s, type=%s"%(hrn,type))
     
         # verify_cancreate_credential requires that the member lists
         # (researchers, pis, etc) be filled in
@@ -142,11 +142,13 @@ class RegistryManager:
     
         # try resolving the remaining unfound records at the local registry
         local_hrns = list ( set(hrns).difference([record['hrn'] for record in records]) )
+        logger.info("Resolve: local_hrns=%s"%local_hrns)
         # 
         local_records = dbsession.query(RegRecord).filter(RegRecord.hrn.in_(local_hrns))
         if intype:
             local_records = local_records.filter_by(type=intype)
         local_records=local_records.all()
+        logger.info("Resolve: local_records=%s (intype=%s)"%(local_records,intype))
         local_dicts = [ record.__dict__ for record in local_records ]
         
         if full:
@@ -261,7 +263,8 @@ class RegistryManager:
             raise ExistingRecord(hrn)
            
         assert ('type' in record_dict)
-        record = RegRecord(dict=record_dict)
+        # returns the right type of RegRecord according to type in record
+        record = make_record(record_dict)
         record.just_created()
         record.authority = get_authority(record.hrn)
         auth_info = api.auth.get_auth_info(record.authority)
@@ -311,7 +314,7 @@ class RegistryManager:
         # make sure the record exists
         record = dbsession.query(RegRecord).filter_by(type=type,hrn=hrn).first()
         if not record:
-            raise RecordNotFound(hrn)
+            raise RecordNotFound("hrn=%s, type=%s"%(hrn,type))
         record.just_updated()
     
         # validate the type
index 349d70d..66703ab 100644 (file)
@@ -62,7 +62,6 @@ class AlchemyObj:
             if isinstance(v, StringTypes) and v.lower() in ['true']: v=True
             if isinstance(v, StringTypes) and v.lower() in ['false']: v=False
             setattr(self,k,v)
-        assert self.type in BUILTIN_TYPES
     
     # in addition we provide convenience for converting to and from xml records
     # for this purpose only, we need the subclasses to define 'fields' as either 
@@ -75,7 +74,6 @@ class AlchemyObj:
         xml_record = XML(xml)
         xml_dict = xml_record.todict()
         logger.info("load from xml, keys=%s"%xml_dict.keys())
-#        for k in self.xml_fields():
         for (k,v) in xml_dict.iteritems():
             setattr(self,k,v)
 
@@ -108,61 +106,52 @@ class AlchemyObj:
 
 
 ##############################
-class Type (Base):
-    __table__ = Table ('types', Base.metadata,
-                       Column ('type',String, primary_key=True),
-                       )
-    def __init__ (self, type): self.type=type
-    def __repr__ (self): return "<Type %s>"%self.type
-    
-#BUILTIN_TYPES = [ 'authority', 'slice', 'node', 'user' ]
-# xxx for compat but sounds useless
-BUILTIN_TYPES = [ 'authority', 'slice', 'node', 'user',
-                  'authority+sa', 'authority+am', 'authority+sm' ]
-
-def insert_builtin_types(dbsession):
-    for type in BUILTIN_TYPES :
-        count = dbsession.query (Type).filter_by (type=type).count()
-        if count==0:
-            dbsession.add (Type (type))
-    dbsession.commit()
+# various kinds of records are implemented as an inheritance hierarchy
+# RegRecord is the base class for all actual variants
 
-##############################
 class RegRecord (Base,AlchemyObj):
     # xxx tmp would be 'records'
-    __table__ = Table ('records', Base.metadata,
-                       Column ('record_id', Integer, primary_key=True),
-                       Column ('type', String, ForeignKey ("types.type")),
-                       Column ('hrn',String),
-                       Column ('gid',String),
-                       Column ('authority',String),
-                       Column ('peer_authority',String),
-                       Column ('pointer',Integer,default=-1),
-                       Column ('date_created',DateTime),
-                       Column ('last_updated',DateTime),
-                       )
+    __tablename__       = 'records'
+    record_id           = Column (Integer, primary_key=True)
+    type                = Column (String)
+    hrn                 = Column (String)
+    gid                 = Column (String)
+    authority           = Column (String)
+    peer_authority      = Column (String)
+    pointer             = Column (Integer, default=-1)
+    date_created        = Column (DateTime)
+    last_updated        = Column (DateTime)
+    # use the 'type' column to decide which subclass the object is of
+    __mapper_args__     = { 'polymorphic_on' : type }
+
     fields = [ 'type', 'hrn', 'gid', 'authority', 'peer_authority' ]
     def __init__ (self, type='unknown', hrn=None, gid=None, authority=None, peer_authority=None, 
                   pointer=None, dict=None):
-        self.type=type
-        if hrn: self.hrn=hrn
+# managed by alchemy's polymorphic stuff
+#        self.type=type
+        if hrn:                                 self.hrn=hrn
         if gid: 
-            if isinstance(gid, StringTypes): self.gid=gid
-            else: self.gid=gid.save_to_string(save_parents=True)
-        if authority: self.authority=authority
-        if peer_authority: self.peer_authority=peer_authority
-        if pointer: self.pointer=pointer
-        if dict:
-            self.load_from_dict (dict)
+            if isinstance(gid, StringTypes):    self.gid=gid
+            else:                               self.gid=gid.save_to_string(save_parents=True)
+        if authority:                           self.authority=authority
+        if peer_authority:                      self.peer_authority=peer_authority
+        if pointer:                             self.pointer=pointer
+        if dict:                                self.load_from_dict (dict)
 
     def __repr__(self):
-        result="[Record(record_id=%s, hrn=%s, type=%s, authority=%s, pointer=%s" % \
-                (self.record_id, self.hrn, self.type, self.authority, self.pointer)
-        if self.gid: result+=" %s..."%self.gid[:10]
-        else: result+=" no-gid"
+        result="[Record id=%s, type=%s, hrn=%s, authority=%s, pointer=%s" % \
+                (self.record_id, self.type, self.hrn, self.authority, self.pointer)
+        # skip the uniform '--- BEGIN CERTIFICATE --' stuff
+        if self.gid: result+=" gid=%s..."%self.gid[28:36]
+        else: result+=" nogid"
         result += "]"
         return result
 
+    @validates ('gid')
+    def validate_gid (self, key, gid):
+        if isinstance(gid, StringTypes):    return gid
+        else:                               return gid.save_to_string(save_parents=True)
+
     # xxx - there might be smarter ways to handle get/set'ing gid using validation hooks 
     def get_gid_object (self):
         if not self.gid: return None
@@ -178,42 +167,68 @@ class RegRecord (Base,AlchemyObj):
         self.last_updated=now
 
 ##############################
-
-class User (Base):
-    __table__ = Table ('users', Base.metadata,
-                       Column ('record_id', Integer, ForeignKey ("records.record_id"), primary_key=True),
-                       Column ('email', String),
-                       )
-    def __init__ (self, email):
-        self.email=email
-    def __repr__ (self): return "[User(%d) email=%s>"%(self.record_id,self.email,)
+class RegUser (RegRecord):
+    __tablename__       = 'users'
+    # these objects will have type='user' in the records table
+    __mapper_args__     = { 'polymorphic_identity' : 'user' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    email               = Column ('email', String)
+    
+    # append stuff at the end of the record __repr__
+    def __repr__ (self): 
+        result = RegRecord.__repr__(self).replace("Record","User")
+        result.replace ("]"," email=%s"%self.email)
+        return result
     
     @validates('email') 
     def validate_email(self, key, address):
         assert '@' in address
         return address
-                           
-class Key (Base):
-    __table__ = Table ('keys', Base.metadata,
-                       Column ('key_id', Integer, primary_key=True),
-                       Column ('key',String),
-                       )
 
-##############################
-#record_table = RegRecord.__table__
-#user_table = User.__table__
-#record_user_join = join (record_table, user_table)
-#
-#class UserRecord (Base):
-#    __table__ = record_user_join
-#    record_id = column_property (record_table.c.record_id, user_table.c.record_id)
-#    user_id = user_table.c.user_id
-#    def __init__ (self, gid, email):
-#        self.type='user'
-#        self.gid=gid
-#        self.email=email
-#    def __repr__ (self): return "<UserRecord %s %s>"%(self.email,self.gid)
-#
+class RegAuthority (RegRecord):
+    __tablename__       = 'authorities'
+    __mapper_args__     = { 'polymorphic_identity' : 'authority' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    
+    # no proper data yet, just hack the typename
+    def __repr__ (self):
+        return RegRecord.__repr__(self).replace("Record","Authority")
+
+class RegSlice (RegRecord):
+    __tablename__       = 'slices'
+    __mapper_args__     = { 'polymorphic_identity' : 'slice' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    
+    def __repr__ (self):
+        return RegRecord.__repr__(self).replace("Record","Slice")
+
+class RegNode (RegRecord):
+    __tablename__       = 'nodes'
+    __mapper_args__     = { 'polymorphic_identity' : 'node' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    
+    def __repr__ (self):
+        return RegRecord.__repr__(self).replace("Record","Node")
+
+# because we use 'type' as the discriminator here, the only way to have type set to
+# e.g. authority+sa is to define a separate class
+# this currently is not used at all though, just to check if all this stuff really is useful
+# if so it would make more sense to store that in the authorities table instead
+class RegTmpAuthSa (RegRecord):
+    __tablename__       = 'authorities_sa'
+    __mapper_args__     = { 'polymorphic_identity' : 'authority+sa' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+
+class RegTmpAuthAm (RegRecord):
+    __tablename__       = 'authorities_am'
+    __mapper_args__     = { 'polymorphic_identity' : 'authority+am' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+
+class RegTmpAuthSm (RegRecord):
+    __tablename__       = 'authorities_sm'
+    __mapper_args__     = { 'polymorphic_identity' : 'authority+sm' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+
 ##############################
 def init_tables(dbsession):
     logger.info("Initializing db schema and builtin types")
@@ -224,10 +239,31 @@ def init_tables(dbsession):
     # so let's import alchemy - but not from toplevel 
     from sfa.storage.alchemy import engine
     Base.metadata.create_all(engine)
-    insert_builtin_types(dbsession)
 
 def drop_tables(dbsession):
     logger.info("Dropping tables")
     # same as for init_tables
     from sfa.storage.alchemy import engine
     Base.metadata.drop_all(engine)
+
+# convert an incoming record - typically from xmlrpc - into an object
+def make_record (record_dict):
+    assert ('type' in record_dict)
+    type=record_dict['type']
+    if type=='authority':
+        result=RegAuthority (dict=record_dict)
+    elif type=='user':
+        result=RegUser (dict=record_dict)
+    elif type=='slice':
+        result=RegSlice (dict=record_dict)
+    elif type=='node':
+        result=RegNode (dict=record_dict)
+    else:
+        result=RegRecord (dict=record_dict)
+    logger.info ("converting dict into Reg* with type=%s"%type)
+    logger.info ("returning=%s"%result)
+    # xxx todo
+    # register non-db attributes in an extensions field
+    return result
+        
+