Skip to content

Commit 7246244

Browse files
committed
Improve udev installation
Add command line function to install udev rule, and tell user where to find the file if they want to install it manually.
1 parent 2f425f9 commit 7246244

File tree

4 files changed

+72
-10
lines changed

4 files changed

+72
-10
lines changed
File renamed without changes.

pslab/cli.py

+59
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
import argparse
1616
import csv
1717
import json
18+
import platform
19+
import os.path
20+
import shutil
1821
import sys
1922
import time
2023
from itertools import zip_longest
2124
from typing import List, Tuple
2225

2326
import numpy as np
2427

28+
import pslab
2529
import pslab.protocol as CP
2630
from pslab.instrument.logic_analyzer import LogicAnalyzer
2731
from pslab.instrument.oscilloscope import Oscilloscope
@@ -224,6 +228,10 @@ def main(args: argparse.Namespace):
224228
args : :class:`argparse.Namespace`
225229
Parsed arguments.
226230
"""
231+
if args.function == "install":
232+
install(args)
233+
return
234+
227235
handler = SerialHandler(port=args.port)
228236

229237
if args.function == "collect":
@@ -456,4 +464,55 @@ def cmdline(args: List[str] = None):
456464
add_collect_args(subparser)
457465
add_wave_args(subparser)
458466
add_pwm_args(subparser)
467+
add_install_args(subparser)
459468
main(parser.parse_args(args))
469+
470+
471+
def install(args: argparse.Namespace):
472+
"""Install udev rule on Linux.
473+
474+
Parameters
475+
----------
476+
args : :class:`argparse.Namespace`
477+
Parsed arguments.
478+
"""
479+
if not platform.system() == "Linux":
480+
print(f"Installation not required on {platform.system()}.")
481+
return
482+
else:
483+
try:
484+
SerialHandler.check_udev()
485+
except OSError:
486+
_install()
487+
return
488+
489+
if args.force:
490+
_install()
491+
return
492+
493+
print("udev rule already installed.")
494+
495+
496+
def _install():
497+
udev_rules = os.path.join(pslab.__path__[0], "99-pslab.rules")
498+
target = "/etc/udev/rules.d/99-pslab.rules"
499+
shutil.copyfile(udev_rules, target)
500+
return
501+
502+
503+
def add_install_args(subparser: argparse._SubParsersAction):
504+
"""Add arguments for install function to ArgumentParser.
505+
506+
Parameters
507+
----------
508+
subparser : :class:`argparse._SubParsersAction`
509+
SubParser to add other arguments related to install function.
510+
"""
511+
install = subparser.add_parser("install")
512+
install.add_argument(
513+
"-f",
514+
"--force",
515+
action="store_true",
516+
default=False,
517+
help="Overwrite existing udev rules.",
518+
)

pslab/serial_handler.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import serial
1919
from serial.tools import list_ports
2020

21+
import pslab
2122
import pslab.protocol as CP
2223

2324
logger = logging.getLogger(__name__)
@@ -79,7 +80,7 @@ def __init__(
7980
baudrate: int = 1000000,
8081
timeout: float = 1.0,
8182
):
82-
self._check_udev()
83+
self.check_udev()
8384
self.version = ""
8485
self._log = b""
8586
self._logging = False
@@ -98,7 +99,8 @@ def __init__(
9899
self.connected = self.interface.is_open
99100

100101
@staticmethod
101-
def _check_udev():
102+
def check_udev():
103+
"""Check if udev rule is installed on Linux."""
102104
if platform.system() == "Linux":
103105
udev_paths = [
104106
"/run/udev/rules.d/",
@@ -110,11 +112,11 @@ def _check_udev():
110112
if os.path.isfile(udev_rules):
111113
break
112114
else:
113-
e = (
115+
raise OSError(
114116
"A udev rule must be installed to access the PSLab. "
115-
+ "Please copy 99-pslab.rules to /etc/udev/rules.d/."
117+
"Please run 'pslab install' as root, or copy "
118+
f"{pslab.__path__[0]}/99-pslab.rules to {udev_paths[1]}."
116119
)
117-
raise OSError(e)
118120

119121
@staticmethod
120122
def _list_ports() -> List[str]:
@@ -401,7 +403,8 @@ def __init__(
401403
super().__init__(port, baudrate, timeout)
402404

403405
@staticmethod
404-
def _check_udev():
406+
def check_udev():
407+
"""See :meth:`SerialHandler.check_udev`."""
405408
pass
406409

407410
def connect(

tests/test_serial_handler.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def mock_serial(mocker):
3030

3131
@pytest.fixture
3232
def mock_handler(mocker, mock_serial, mock_list_ports):
33-
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
33+
mocker.patch("pslab.serial_handler.SerialHandler.check_udev")
3434
mock_list_ports.grep.return_value = mock_ListPortInfo()
3535
return SerialHandler()
3636

@@ -47,21 +47,21 @@ def test_detect(mocker, mock_serial, mock_list_ports):
4747

4848
def test_connect_scan_port(mocker, mock_serial, mock_list_ports):
4949
mock_list_ports.grep.return_value = mock_ListPortInfo()
50-
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
50+
mocker.patch("pslab.serial_handler.SerialHandler.check_udev")
5151
SerialHandler()
5252
mock_serial().open.assert_called()
5353

5454

5555
def test_connect_scan_failure(mocker, mock_serial, mock_list_ports):
5656
mock_list_ports.grep.return_value = mock_ListPortInfo(found=False)
57-
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
57+
mocker.patch("pslab.serial_handler.SerialHandler.check_udev")
5858
with pytest.raises(SerialException):
5959
SerialHandler()
6060

6161

6262
def test_connect_multiple_connected(mocker, mock_serial, mock_list_ports):
6363
mock_list_ports.grep.return_value = mock_ListPortInfo(multiple=True)
64-
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
64+
mocker.patch("pslab.serial_handler.SerialHandler.check_udev")
6565
with pytest.raises(RuntimeError):
6666
SerialHandler()
6767

0 commit comments

Comments
 (0)