diff --git a/tests/fixtures.py b/tests/fixtures.py index 7ff7e26d4408..7f05e3c6af8e 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,6 +1,7 @@ from utils import TEST_NETWORK, VALGRIND # noqa: F401,F403 -from pyln.testing.fixtures import directory, test_base_dir, test_name, chainparams, node_factory, bitcoind, teardown_checks, db_provider, executor, setup_logging, jsonschemas # noqa: F401,F403 +from pyln.testing.fixtures import directory, test_base_dir, test_name, chainparams, bitcoind, teardown_checks, db_provider, executor, setup_logging, jsonschemas # noqa: F401,F403 from pyln.testing import utils +from pyln.testing.utils import NodeFactory as _NodeFactory from utils import COMPAT from pathlib import Path @@ -11,19 +12,71 @@ import subprocess import tempfile import time +from pyln.testing.utils import env +from vls import ValidatingLightningSignerD + + +class NodeFactory(_NodeFactory): + """Make `use_vls` option reaches the `LightningNode.__init__` in + `NodeFactory` as node-level kwarg instead of being forwarded as a + lightningd CLI flag.""" + + def split_options(self, opts): + node_opts, cli_opts = super().split_options(opts) + if 'use_vls' in cli_opts: + node_opts['use_vls'] = cli_opts.pop('use_vls') + return node_opts, cli_opts @pytest.fixture def node_cls(): return LightningNode +# Override the default fixture to use the new `NodeFactory` which supports `use_vls` as a node-level option. +@pytest.fixture +def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks, node_cls, jsonschemas): + nf = NodeFactory( + request, + test_name, + bitcoind, + executor, + directory=directory, + db_provider=db_provider, + node_cls=node_cls, + jsonschemas=jsonschemas, + ) + + yield nf + ok, errs = nf.killall([not n.may_fail for n in nf.nodes]) + + for e in errs: + print(e.format()) + + if not ok: + raise Exception("At least one lightning exited with unexpected non-zero return code") + class LightningNode(utils.LightningNode): - def __init__(self, *args, **kwargs): + def __init__(self, node_id, *args, use_vls=None, **kwargs): # Yes, we really want to test the local development version, not # something in out path. kwargs["executable"] = "lightningd/lightningd" - utils.LightningNode.__init__(self, *args, **kwargs) + utils.LightningNode.__init__(self, node_id, *args, **kwargs) + + self.node_id = node_id + self.network = TEST_NETWORK + + if use_vls is True: + self.vls_mode = "cln:socket" + elif use_vls is False: + self.vls_mode = "cln:native" + else: + # use_vls=None (default) falls back to the VLS_MODE env var. + # Setting this env var causes all nodes use the same mode + self.vls_mode = env("VLS_MODE", "cln:native") + + self.use_vls = use_vls is not None + self.vlsd: ValidatingLightningSignerD | None = None # Avoid socket path name too long on Linux if os.uname()[0] == 'Linux' and \ @@ -61,6 +114,33 @@ def __init__(self, *args, **kwargs): accts_db = self.db.provider.get_db('', 'accounts', 0) self.daemon.opts['bookkeeper-db'] = accts_db.get_dsn() + def start(self, wait_for_bitcoind_sync=True, stderr_redir=False): + # We start the signer first, otherwise the lightningd startup hangs on the init message. + if self.use_vls: + self.vlsd = ValidatingLightningSignerD( + lightning_dir=self.lightning_dir, + node_id=self.node_id, + network=self.network, + ) + self.daemon.opts["subdaemon"] = f"hsmd:{self.vlsd.remote_socket}" + self.daemon.env["VLS_PORT"] = str(self.vlsd.port) + self.daemon.env["VLS_LSS"] = os.environ.get("LSS_URI", "") + import threading + threading.Timer(1, self.vlsd.start).start() + + utils.LightningNode.start( + self, + wait_for_bitcoind_sync=wait_for_bitcoind_sync, + stderr_redir=stderr_redir, + ) + + def stop(self, timeout: int = 10): + utils.LightningNode.stop(self, timeout=timeout) + if self.vlsd is not None: + rc = self.vlsd.stop(timeout=timeout) + print(f"VLSD2 exited with rc={rc}") + + class CompatLevel(object): """An object that encapsulates the compat-level of our build. diff --git a/tests/test_pay.py b/tests/test_pay.py index f73262c437da..4b17cb63a511 100644 --- a/tests/test_pay.py +++ b/tests/test_pay.py @@ -22,6 +22,18 @@ import unittest +@pytest.mark.openchannel('v1') +@pytest.mark.openchannel('v2') +def test_vls_simple(node_factory): + l1, l2 = node_factory.line_graph(2, opts={'use_vls': True}) + + inv = l2.rpc.invoice(123000, 'test_vls_simple', 'description')['bolt11'] + details = l1.dev_pay(inv, dev_use_shadow=False) + assert details['status'] == 'complete' + assert details['amount_msat'] == Millisatoshi(123000) + assert details['destination'] == l2.info['id'] + + @pytest.mark.openchannel('v1') @pytest.mark.openchannel('v2') def test_pay(node_factory): diff --git a/tests/vls.py b/tests/vls.py new file mode 100644 index 000000000000..3f8631290b17 --- /dev/null +++ b/tests/vls.py @@ -0,0 +1,157 @@ +from pyln.testing.utils import TailableProc, env, reserve_unused_port +import logging +import os +import json +from pathlib import Path +from enum import Enum +from subprocess import run, PIPE +from typing import Union +import sys +import time + +__VERSION__ = "0.0.1" + +logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] %(levelname)s: %(message)s', + handlers=[logging.StreamHandler(stream=sys.stdout)], +) + +def chunk_string(string: str, size: int): + for i in range(0, len(string), size): + yield string[i: i + size] + + +def ratelimit_output(output: str): + sys.stdout.reconfigure(encoding='utf-8') + for i in chunk_string(output, 1024): + sys.stdout.write(i) + sys.stdout.flush() + time.sleep(0.01) + + +class Logger: + """Redirect logging output to a json object or stdout as appropriate.""" + def __init__(self, capture: bool = False): + self.json_output = {"result": [], + "log": []} + self.capture = capture + + def str_esc(self, raw_string: str) -> str: + assert isinstance(raw_string, str) + return json.dumps(raw_string)[1:-1] + + def debug(self, to_log: str): + assert isinstance(to_log, str) or hasattr(to_log, "__repr__") + if logging.root.level > logging.DEBUG: + return + if self.capture: + self.json_output['log'].append(self.str_esc(f"DEBUG: {to_log}")) + else: + logging.debug(to_log) + + def info(self, to_log: str): + assert isinstance(to_log, str) or hasattr(to_log, "__repr__") + if logging.root.level > logging.INFO: + return + if self.capture: + self.json_output['log'].append(self.str_esc(f"INFO: {to_log}")) + else: + print(to_log) + + def warning(self, to_log: str): + assert isinstance(to_log, str) or hasattr(to_log, "__repr__") + if logging.root.level > logging.WARNING: + return + if self.capture: + self.json_output['log'].append(self.str_esc(f"WARNING: {to_log}")) + else: + logging.warning(to_log) + + def error(self, to_log: str): + assert isinstance(to_log, str) or hasattr(to_log, "__repr__") + if logging.root.level > logging.ERROR: + return + if self.capture: + self.json_output['log'].append(self.str_esc(f"ERROR: {to_log}")) + else: + logging.error(to_log) + + def add_result(self, result: Union[str, None]): + assert json.dumps(result), "result must be json serializable" + self.json_output["result"].append(result) + + def reply_json(self): + """json output to stdout with accumulated result.""" + if len(log.json_output["result"]) == 1 and \ + isinstance(log.json_output["result"][0], list): + # unpack sources output + log.json_output["result"] = log.json_output["result"][0] + output = json.dumps(log.json_output, indent=3) + '\n' + ratelimit_output(output) + + +log = Logger() + +repos = ["https://gitlab.com/lightning-signer/validating-lightning-signer.git"] + + +class ValidatingLightningSignerD(TailableProc): + def __init__(self, lightning_dir, node_id, network): + logging.info("Initializing ValidatingLightningSignerD") + log.info(f"Cloning repository into {lightning_dir}") + self.lightning_dir = lightning_dir + clone = run(['git', 'clone', repos[0]], cwd=self.lightning_dir, check=True, timeout=120) + signer_folder = repos[0].split("/")[-1].split(".git")[0] + vlsd_dir = Path(self.lightning_dir / signer_folder).resolve() + self.dir = vlsd_dir + self.port = reserve_unused_port() + self.rpc_port = reserve_unused_port() + + if clone.returncode != 0: + log.error(f"Failed to clone repository: {clone.stderr}") + else: + log.info(f"Successfully cloned repository: {clone.stdout}") + + cargo = run(['cargo', 'build'], cwd=self.dir, check=True, timeout=300) + if cargo.returncode != 0: + log.error(f"Failed to build vlsd: {cargo.stderr}") + else: + log.info("Successfully built vlsd") + + TailableProc.__init__(self, self.dir, verbose=True) + self.executable = env("REMOTE_SIGNER_CMD", str(Path(self.dir / "target" / "debug" / "vlsd").resolve())) + self.remote_socket = Path(self.dir / "target" / "debug" / "remote_hsmd_socket").resolve() + os.environ['ALLOWLIST'] = env( + 'REMOTE_SIGNER_ALLOWLIST', + 'contrib/remote_hsmd/TESTING_ALLOWLIST') + self.opts = [ + '--network={}'.format(network), + '--datadir={}'.format(self.dir), + '--connect=http://localhost:{}'.format(self.port), + '--rpc-server-port={}'.format(self.rpc_port), + '--integration-test', + ] + self.prefix = 'vlsd-%d' % (node_id) + + @property + def cmd_line(self): + return [self.executable] + self.opts + + def start(self, stdin=None, stdout_redir=True, stderr_redir=True, + wait_for_initialized=True): + TailableProc.start(self, stdin, stdout_redir, stderr_redir) + # We need to always wait for initialization + self.wait_for_log("vlsd git_desc") + logging.info("vlsd started") + + def stop(self, timeout=10): + logging.info("stopping vlsd") + rc = TailableProc.stop(self, timeout) + logging.info("vlsd stopped") + self.logs_catchup() + return rc + + def __del__(self): + self.logs_catchup() +