diff --git a/tools/pamusb-agent b/tools/pamusb-agent index 4ab4a535..28d25ce9 100755 --- a/tools/pamusb-agent +++ b/tools/pamusb-agent @@ -40,9 +40,10 @@ from gi.repository import GLib, UDisks import xml.etree.ElementTree as et class HotPlugDevice: - def __init__(self, serial): + def __init__(self, serial, name): self.__udi = None self.__serial = serial + self.__name = name self.__callbacks = [] self.__running = False @@ -92,7 +93,7 @@ class HotPlugDevice: return self.__udi = udi if self.__running: - [ cb('added') for cb in self.__callbacks ] + [ cb('added', self.__name) for cb in self.__callbacks ] def __deviceRemoved(self, udi): if self.__udi is None: @@ -101,7 +102,7 @@ class HotPlugDevice: return self.__udi = None if self.__running: - [ cb('removed') for cb in self.__callbacks ] + [ cb('removed', self.__name) for cb in self.__callbacks ] class Log: def __init__(self): @@ -209,24 +210,27 @@ def userDeviceThread(user): } ) - # @todo: adjust for multiple devices here, should be an array of devices - deviceName = user.find('device').text.strip() + device_names = {} + to_watch = {} + + all_devices = doc.findall("devices/device") + user_devices = user.findall("device") + for device in user_devices: + device_names += device.get('id') - devices = doc.findall("devices/device") deviceOK = False - for device in devices: - if device.get('id') == deviceName: # should loop all devices and monitor them all, no part should be devicename bound + for device in all_devices: + if device.get('id') in device_names: + to_watch += {"name": device.get('id'), "serial": device.get('serial')} deviceOK = True - break if not deviceOK: - logger.error('Device %s not found in configuration file.' % deviceName) + logger.error('Device(s) not found in configuration file.') return 1 - serial = device.find('serial').text.strip() resumeTimestamp = datetime.datetime.min - def authChangeCallback(event): + def authChangeCallback(event, deviceName): if event == 'removed': nonlocal resumeTimestamp currentTimestamp = datetime.datetime.now() @@ -268,7 +272,7 @@ def userDeviceThread(user): logger.info('Process exit code: %d' % (process.returncode)) logger.info('Process stdout: %s' % (process.stdout.decode())) logger.info('Process stderr: %s' % (process.stderr.decode())) - + else: logger.info('No commands defined for unlock!') @@ -280,28 +284,32 @@ def userDeviceThread(user): def onSuspendOrResume(start, member=None): nonlocal resumeTimestamp - nonlocal hpDev + nonlocal hpDevs - if start == True: - logger.info('Suspending user "%s"' % (userName)) - resumeTimestamp = datetime.datetime.max - else: - logger.info('Resuming user "%s"' % (userName)) - if hpDev.isDeviceConnected() == True: - logger.info('Device is connected for user "%s", unlocking' % (userName)) - authChangeCallback('added') + for hpDev in hpDevs: + if start == True: + logger.info('Suspending user "%s"' % (userName)) + resumeTimestamp = datetime.datetime.max + else: + logger.info('Resuming user "%s"' % (userName)) + if hpDev.isDeviceConnected() == True: + logger.info('Device %s is connected for user "%s", unlocking' % (hpDev.__name, userName)) + authChangeCallback('added') - resumeTimestamp = datetime.datetime.now() + resumeTimestamp = datetime.datetime.now() login1Interface = login1ManagerDBusIface() for signal in ['PrepareForSleep', 'PrepareForShutdown']: login1Interface.connect_to_signal(signal, onSuspendOrResume, member_keyword='member') - hpDev = HotPlugDevice(serial) - hpDev.addCallback(authChangeCallback) + hpDevs = {} + for watch_this in to_watch: + hpDev = HotPlugDevice(watch_this.get('serial'), watch_this.get('name')) + hpDev.addCallback(authChangeCallback) + hpDevs += hpDev - logger.info('Watching device "%s" for user "%s"' % (deviceName, userName)) - hpDev.run() + logger.info('Watching device "%s" for user "%s"' % (watch_this.get('name'), userName)) + hpDev.run() udisks = UDisks.Client.new_sync() udisksObjectManager = udisks.get_object_manager() @@ -356,10 +364,10 @@ if options['daemon'] and os.fork(): sys.exit(0) def sig_handler(sig, frame): - logger.info('Stopping agent.') - sys.exit(0) + logger.info('Stopping agent.') + sys.exit(0) sys_signals = ['SIGINT', 'SIGTERM', 'SIGTSTP', 'SIGTTIN', 'SIGTTOU'] for i in sys_signals: - signal.signal(getattr(signal, i), sig_handler) + signal.signal(getattr(signal, i), sig_handler)