diff --git a/models.py b/models.py index 875f310..88869be 100755 --- a/models.py +++ b/models.py @@ -1,4 +1,5 @@ import ldap3 +import random MAXENTRIES = 1000000 OBJECTCLASSES = ['top', 'person', 'organizationalPerson', 'inetOrgPerson', 'posixAccount', 'shadowAccount'] @@ -8,37 +9,54 @@ class ldapbind(): def __init__(self, ldap_host: str, admin_user: str, admin_pass: str, base: str): self.ldap_host = ldap_host self.admin_user = admin_user - self.admin_pass= admin_pass + self.admin_pass = admin_pass self.base = base ldapserver = ldap3.Server(self.ldap_host,use_ssl=True) self.conn = ldap3.Connection(ldapserver, admin_user, admin_pass, auto_bind=True) class logbranch(): - def __init__(self, ldap_server: ldapbind, log_server: ldapbind): + def __init__(self, ldap_server: ldapbind, log_server: ldapbind = None): self.ldap_server = ldap_server self.log_server = log_server + if self.log_server == None: + dc = ldap_server.base.split(',') + self.dc = ','.join(dc[1:]) + self.log_server = ldapbind(ldap_server.ldap_host, ldap_server.admin_user, ldap_server.admin_pass, f'ou=log,{self.dc}') + self.base = self.ldap_server.base self.logbase = self.log_server.base - self.logconnection = self.log_server.conn self.ldapconnection = self.ldap_server.conn self.logconnection = self.log_server.conn + # Unique id of the log branch + self.id = self.getid() + # How many changes we applied to the LDAP server self.loaded = self.getloaded() # How many changes are recoreded in total self.total = self.gettotal() + def getid(self)->int: + self.logconnection.search(search_base=f'uid=id,{self.logbase}',search_filter='(objectClass=person)', attributes=['uidNumber']) + response = self.logconnection.response + if response == []: + id = random.randint(10**12, 10**13-1) + self.logconnection.add(f'uid=id,{self.logbase}', OBJECTCLASSES, { 'uid' : 'id', 'uidNumber' : id }) + return id + return int(response[0]['attributes']['uidNumber']) + def gettotal(self)->int: self.logconnection.search(search_base=f'uid=total,{self.logbase}',search_filter='(objectClass=person)', attributes=['uidNumber']) response = self.logconnection.response if response == []: - response = 0 - return int(response['attributes']['uidNumber']) + settotal(0) + return 0 + return int(response[0]['attributes']['uidNumber']) def settotal(self, newvalue: int): newvalue = int(newvalue) @@ -48,16 +66,18 @@ class logbranch(): response = self.logconnection.response if response == []: - return self.logconnection.add(f'uid=total,{self.logbase}', OBJECTCLASSES, { 'uid' : 'total', 'uidNumber' : 0 }) + self.logconnection.add(f'uid=total,{self.logbase}', OBJECTCLASSES, { 'uid' : 'total', 'uidNumber' : 0 }) - return self.logconnection.modify(f'uid=total,{self.logbase}', {'uidNumber' : (ldap3.MODIFY_REPLACE, [newvalue])}) + self.logconnection.modify(f'uid=total,{self.logbase}', {'uidNumber' : (ldap3.MODIFY_REPLACE, [newvalue])}) + return self.logconnection.response def getloaded(self): self.connection.search(search_base=f'uid=loaded,{self.logbase}',search_filter='(objectClass=person)', attributes=['uidNumber']) response = self.logconnection.response if response == []: - response = 0 - return int(self.connection.response['attributes']['uidNumber']) + setloaded(0) + return 0 + return int(self.connection.response[0]['attributes']['uidNumber']) def setloaded(self, newvalue: int): newvalue = int(newvalue) @@ -67,16 +87,17 @@ class logbranch(): response = self.logconnection.response if response == []: - return self.logconnection.add(f'uid=loaded,{self.logbase}', OBJECTCLASSES, { 'uid' : 'loaded', 'uidNumber' : 0 }) + self.logconnection.add(f'uid=loaded,{self.logbase}', OBJECTCLASSES, { 'uid' : 'loaded', 'uidNumber' : 0 }) - return self.logconnection.modify(f'uid=loaded,{self.logbase}', {'uidNumber' : (ldap3.MODIFY_REPLACE, [newvalue])}) + self.logconnection.modify(f'uid=loaded,{self.logbase}', {'uidNumber' : (ldap3.MODIFY_REPLACE, [newvalue])}) + return self.logconnection.response def refreshtotal(self): self.settotal(self.findtotal()) return self.gettotal() def findtotal(self): - for entry in range(self.gettotal(), MAXENTRIES): + for entry in range(self.gettotal() + 1, MAXENTRIES): self.logconnection.search(search_base=f'uid={entry},{self.logbase}',search_filter='(objectClass=person)', attributes=['uidNumber']) if self.logconnection.response == []: return entry - 1 @@ -84,8 +105,8 @@ class logbranch(): def applylogs(self): self.refreshtotal() - loaded = int(self.getloaded()) + 1 - total = int(self.gettotal()) + 1 + loaded = self.getloaded() + 1 + total = self.gettotal() + 1 for log_number in range(loaded, total): self.setlog(log_number) @@ -98,7 +119,7 @@ class logbranch(): def getlog(self, log_number: int): self.logconnection.search(search_base=f'uid={log_number},{self.logbase}',search_filter = '(objectClass=person)', attributes = USERATTRIBUTES) - return logconnection.response + return logconnection.response[0] def setlog(self, log_number: int): log = self.getlog(log_number) @@ -121,17 +142,44 @@ class logbranch(): return ldapconnection.response class logtree(): - def __init__(self, ldap_server: ldapbind): - self.ldap_server = ldap_server + def __init__(self): self.branches = [] def add(self, log_server: logbranch): - self.branches.append(log_server) + # Find ID of new branch + id = log_server.logconnection.search(search_base=f'{log_server.logbase}',search_filter='(objectClass=person)', attributes=['uidNumber']) + + remotebranches = [] + for branch in self.branches: + # Create base for log copies of the new log branch to all existing ones + branch.ldapconnection.search(search_base=f'ou=ldapsync{id},{branch.dc}',search_filter='(objectClass=organizationalUnit)', attributes=['ou']) + if branch.ldapconnection.response == []: + branch.ldapconnection.add(f'ou=ldapsync{id},{branch.dc}', ['top', 'organizationalUnit'], {'ou' : id}) + remotebranch = logbranch(branch.log_server) + remotebranch.logbase = f'ou=ldapsync{id},{branch.dc}' + remotebranches.append(remotebranch) + + branchdata = {'local' : log_server , 'remote' : remotebranches} + self.branches.append(branchdata) def remove(self, log_server: logbranch): - self.branches.remove(log_server) + for branch in self.branches: + if branch['local'] == log_server: + self.branches.remove(branch) + + def push(self): + for branch in self.branches: + localtotal = branch.gettotal() + for remotebranch in branch['remote']: + remotetotal = remotebranch.gettotal() + for log in range(remotetotal + 1, localtotal + 1): + locallog = branch['local'].getlog(log) + remotebranch.logconnection.add(f'uid={log},{remotebranch.logbase}', OBJECTCLASSES, locallog['attributes']) def sync(self): + self.push() for branch in self.branches: - branch.applylogs() + branch['local'].applylogs() + for localbranch in branch['remote']: + localbranch.applylogs()