Skip to content

Commit

Permalink
#31: Start to adjust agent
Browse files Browse the repository at this point in the history
  • Loading branch information
mcdope committed Jul 14, 2024
1 parent 1aca921 commit 914896b
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions tools/pamusb-agent
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ from gi.repository import GLib, UDisks
import xml.etree.ElementTree as et

class HotPlugDevice:
def __init__(self, serial):
def __init__(self, serial, name):
self.__udi = None
self.__serial = serial
self.__name = name
self.__callbacks = []
self.__running = False

Expand Down Expand Up @@ -92,7 +93,7 @@ class HotPlugDevice:
return
self.__udi = udi
if self.__running:
[ cb('added') for cb in self.__callbacks ]
[ cb('added', self.__name) for cb in self.__callbacks ]

def __deviceRemoved(self, udi):
if self.__udi is None:
Expand All @@ -101,7 +102,7 @@ class HotPlugDevice:
return
self.__udi = None
if self.__running:
[ cb('removed') for cb in self.__callbacks ]
[ cb('removed', self.__name) for cb in self.__callbacks ]

class Log:
def __init__(self):
Expand Down Expand Up @@ -209,24 +210,27 @@ def userDeviceThread(user):
}
)

# @todo: adjust for multiple devices here, should be an array of devices
deviceName = user.find('device').text.strip()
device_names = {}
to_watch = {}

all_devices = doc.findall("devices/device")
user_devices = user.findall("device")
for device in user_devices:
device_names += device.get('id')

devices = doc.findall("devices/device")
deviceOK = False
for device in devices:
if device.get('id') == deviceName: # should loop all devices and monitor them all, no part should be devicename bound
for device in all_devices:
if device.get('id') in device_names:
to_watch += {"name": device.get('id'), "serial": device.get('serial')}
deviceOK = True
break

if not deviceOK:
logger.error('Device %s not found in configuration file.' % deviceName)
logger.error('Device(s) not found in configuration file.')
return 1

serial = device.find('serial').text.strip()
resumeTimestamp = datetime.datetime.min

def authChangeCallback(event):
def authChangeCallback(event, deviceName):
if event == 'removed':
nonlocal resumeTimestamp
currentTimestamp = datetime.datetime.now()
Expand Down Expand Up @@ -268,7 +272,7 @@ def userDeviceThread(user):
logger.info('Process exit code: %d' % (process.returncode))
logger.info('Process stdout: %s' % (process.stdout.decode()))
logger.info('Process stderr: %s' % (process.stderr.decode()))

else:
logger.info('No commands defined for unlock!')

Expand All @@ -280,28 +284,32 @@ def userDeviceThread(user):

def onSuspendOrResume(start, member=None):
nonlocal resumeTimestamp
nonlocal hpDev
nonlocal hpDevs

if start == True:
logger.info('Suspending user "%s"' % (userName))
resumeTimestamp = datetime.datetime.max
else:
logger.info('Resuming user "%s"' % (userName))
if hpDev.isDeviceConnected() == True:
logger.info('Device is connected for user "%s", unlocking' % (userName))
authChangeCallback('added')
for hpDev in hpDevs:
if start == True:
logger.info('Suspending user "%s"' % (userName))
resumeTimestamp = datetime.datetime.max
else:
logger.info('Resuming user "%s"' % (userName))
if hpDev.isDeviceConnected() == True:
logger.info('Device %s is connected for user "%s", unlocking' % (hpDev.__name, userName))
authChangeCallback('added')

resumeTimestamp = datetime.datetime.now()
resumeTimestamp = datetime.datetime.now()

login1Interface = login1ManagerDBusIface()
for signal in ['PrepareForSleep', 'PrepareForShutdown']:
login1Interface.connect_to_signal(signal, onSuspendOrResume, member_keyword='member')

hpDev = HotPlugDevice(serial)
hpDev.addCallback(authChangeCallback)
hpDevs = {}
for watch_this in to_watch:
hpDev = HotPlugDevice(watch_this.get('serial'), watch_this.get('name'))
hpDev.addCallback(authChangeCallback)
hpDevs += hpDev

logger.info('Watching device "%s" for user "%s"' % (deviceName, userName))
hpDev.run()
logger.info('Watching device "%s" for user "%s"' % (watch_this.get('name'), userName))
hpDev.run()

udisks = UDisks.Client.new_sync()
udisksObjectManager = udisks.get_object_manager()
Expand Down Expand Up @@ -356,10 +364,10 @@ if options['daemon'] and os.fork():
sys.exit(0)

def sig_handler(sig, frame):
logger.info('Stopping agent.')
sys.exit(0)
logger.info('Stopping agent.')
sys.exit(0)

sys_signals = ['SIGINT', 'SIGTERM', 'SIGTSTP', 'SIGTTIN', 'SIGTTOU']

for i in sys_signals:
signal.signal(getattr(signal, i), sig_handler)
signal.signal(getattr(signal, i), sig_handler)

0 comments on commit 914896b

Please sign in to comment.