# Source code for renku.core.session.session

# Copyright Swiss Data Science Center (SDSC). A partnership between
# École Polytechnique Fédérale de Lausanne (EPFL) and
# Eidgenössische Technische Hochschule Zürich (ETHZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interactive session business logic."""

import os
import shutil
import textwrap
from pathlib import Path
from typing import List, NamedTuple, Optional

from pydantic import ConfigDict, validate_call

from renku.core import errors
from renku.core.config import get_value
from renku.core.plugin.session import get_supported_hibernating_session_providers, get_supported_session_providers
from renku.core.session.utils import get_image_repository_host, get_renku_project_name
from renku.core.util import communication
from renku.core.util.os import safe_read_yaml
from renku.core.util.ssh import SystemSSHConfig, generate_ssh_keys
from renku.domain_model.session import IHibernatingSessionProvider, ISessionProvider, Session, SessionStopStatus


def _safe_get_provider(provider: str) -> ISessionProvider:
    """Look up a supported session provider by its name.

    Args:
        provider(str): Name of the session provider.

    Raises:
        errors.ParameterError: If no supported provider has the given name.

    Returns:
        ISessionProvider: The first supported provider whose ``name`` equals ``provider``.
    """
    # NOTE: Use a ``None`` default instead of catching ``StopIteration`` so the raised error isn't chained to an
    # irrelevant ``StopIteration`` in tracebacks.
    session_provider = next((p for p in get_supported_session_providers() if p.name == provider), None)
    if session_provider is None:
        raise errors.ParameterError(f"Session provider '{provider}' is not available!")
    return session_provider

class SessionList(NamedTuple):
    """Result of listing sessions across one or more session providers."""

    # Sessions gathered from every provider that could be queried.
    sessions: List[Session]
    # False as soon as a remote provider was queried successfully; True otherwise.
    all_local: bool
    # Human-readable warnings for providers that could not be queried.
    warning_messages: List[str]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search_sessions(name: str, provider: Optional[str] = None) -> List[str]:
    """Find sessions whose ID starts with ``name``, ignoring case.

    Args:
        name(str): Prefix to look for.
        provider(Optional[str]): Restrict the lookup to this session provider (Default value = None).

    Returns:
        IDs of all sessions whose ID starts with ``name``.
    """
    available = session_list(provider=provider).sessions
    prefix = name.lower()

    matches = []
    for session in available:
        if session.id.lower().startswith(prefix):
            matches.append(session.id)
    return matches
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search_session_providers(name: str) -> List[str]:
    """Find session providers whose name starts with ``name``, ignoring case.

    Args:
        name(str): Prefix to look for.

    Returns:
        Names of all supported session providers whose name starts with ``name``.
    """
    from renku.core.plugin.session import get_supported_session_providers

    prefix = name.lower()
    matching_names = []
    for candidate in get_supported_session_providers():
        if candidate.name.lower().startswith(prefix):
            matching_names.append(candidate.name)
    return matching_names
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search_hibernating_session_providers(name: str) -> List[str]:
    """Find hibernation-capable session providers whose name starts with ``name``, ignoring case.

    Args:
        name(str): Prefix to look for.

    Returns:
        Names of all session providers that support hibernation and whose name starts with ``name``.
    """
    from renku.core.plugin.session import get_supported_hibernating_session_providers

    prefix = name.lower()
    matching_names = []
    for candidate in get_supported_hibernating_session_providers():
        if candidate.name.lower().startswith(prefix):
            matching_names.append(candidate.name)
    return matching_names
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def session_list(*, provider: Optional[str] = None) -> SessionList:
    """List interactive sessions of the current project.

    Args:
        provider(Optional[str]): Only query this session provider (Default value = None).

    Returns:
        A ``SessionList`` with the sessions found, whether only local providers contributed, and warnings for
        providers that could not be queried.
    """
    project_name = get_renku_project_name()

    if provider:
        providers = [_safe_get_provider(provider)]
    else:
        providers = get_supported_session_providers()

    def list_sessions(session_provider: ISessionProvider) -> List[Session]:
        try:
            return session_provider.session_list(project_name=project_name)
        except errors.RenkulabSessionGetUrlError:
            # An unreachable renkulab is only an error when the user asked for that provider explicitly.
            if provider:
                raise
            return []

    collected: List[Session] = []
    problems: List[str] = []
    only_local = True

    for session_provider in sorted(providers, key=lambda p: p.priority):
        try:
            sessions = list_sessions(session_provider)
        except errors.RenkuException as e:
            problems.append(f"Cannot get sessions list from '{session_provider.name}': {e}")
            continue

        if session_provider.is_remote_provider():
            only_local = False
        collected.extend(sessions)

    return SessionList(collected, only_local, problems)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def session_start(
    config_path: Optional[str],
    provider: str,
    image_name: Optional[str] = None,
    cpu_request: Optional[float] = None,
    mem_request: Optional[str] = None,
    disk_request: Optional[str] = None,
    gpu_request: Optional[str] = None,
    **kwargs,
):
    """Start interactive session.

    Args:
        config_path(Optional[str]): Path to a config YAML passed to the provider; an empty dict is used if not given.
        provider(str): Name of the session provider to use.
        image_name(Optional[str]): Image to start. Falls back to the ``image`` value of the ``interactive`` config
            section, then to ``<project name>:<short commit SHA>`` (prefixed with the image repository host, if any).
        cpu_request(Optional[float]): Number of CPUs to request; falls back to ``interactive.cpu_request``.
        mem_request(Optional[str]): Size of memory to request; falls back to ``interactive.mem_request``.
        disk_request(Optional[str]): Size of disk to request (if supported by provider); falls back to
            ``interactive.disk_request``.
        gpu_request(Optional[str]): Number of GPUs to request; falls back to ``interactive.gpu_request``.
        **kwargs: Provider-specific options forwarded to the provider's ``pre_start_checks``, ``force_build_image``
            and ``session_start``.

    Raises:
        errors.SessionStartError: If HEAD is detached, the image name contains upper-case letters, or the CPU
            request is not a valid float.
        errors.ParameterError: If the session provider is not available.
    """
    from renku.domain_model.project_context import project_context

    if project_context.repository.head.detached:
        raise errors.SessionStartError("Cannot start a session from a detached HEAD. Check out a branch first.")

    # NOTE: The Docker client in Python requires the parameters below to be a list and will fail with a tuple.
    # Click will convert parameters with the flag "many" set to True to tuples. An explicit ``None`` is treated
    # like a missing value instead of crashing in ``list(None)``.
    kwargs["security_opt"] = list(kwargs.get("security_opt") or [])
    kwargs["device_cgroup_rules"] = list(kwargs.get("device_cgroup_rules") or [])

    pinned_image = get_value("interactive", "image")
    if pinned_image and image_name is None:
        image_name = pinned_image

    provider_api = _safe_get_provider(provider)
    config = safe_read_yaml(config_path) if config_path else dict()

    provider_api.pre_start_checks(**kwargs)

    project_name = get_renku_project_name()
    if image_name is None:
        tag = project_context.repository.head.commit.hexsha[:7]
        repo_host = get_image_repository_host()
        image_name = f"{project_name.lower()}:{tag}"
        if repo_host:
            image_name = f"{repo_host}/{image_name}"

    if image_name.lower() != image_name:
        raise errors.SessionStartError(f"Image name '{image_name}' cannot contain upper-case letters.")

    force_build_image = provider_api.force_build_image(**kwargs)

    if not force_build_image and not provider_api.find_image(image_name, config):
        communication.confirm(
            f"The container image '{image_name}' does not exist. Would you like to build it using {provider}?",
            abort=True,
        )
        force_build_image = True

    if force_build_image:
        with communication.busy(msg=f"Building image {image_name}"):
            provider_api.build_image(project_context.dockerfile_path.parent, image_name, config)
        communication.echo(f"Image {image_name} built successfully.")

    # set resource settings; explicit arguments win over values from the ``interactive`` config section
    cpu_limit = cpu_request or get_value("interactive", "cpu_request")

    if cpu_limit is not None:
        try:
            cpu_limit = float(cpu_limit)
        except ValueError as e:
            raise errors.SessionStartError(f"Invalid value for cpu_request (must be float): {cpu_limit}") from e

    disk_limit = disk_request or get_value("interactive", "disk_request")
    mem_limit = mem_request or get_value("interactive", "mem_request")
    gpu = gpu_request or get_value("interactive", "gpu_request")

    with communication.busy(msg="Waiting for session to start..."):
        provider_message, warning_message = provider_api.session_start(
            config=config,
            project_name=project_name,
            image_name=image_name,
            cpu_request=cpu_limit,
            mem_request=mem_limit,
            disk_request=disk_limit,
            gpu_request=gpu,
            **kwargs,
        )

    if warning_message:
        communication.warn(warning_message)
    communication.echo(provider_message)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def session_stop(session_name: Optional[str], stop_all: bool = False, provider: Optional[str] = None):
    """Stop interactive session.

    Args:
        session_name(Optional[str]): Name of the session to stop.
        stop_all(bool): Whether to stop all sessions or just the specified one.
        provider(Optional[str]): Name of the session provider to use; all supported providers are tried if not set.

    Raises:
        errors.ParameterError: If the provider is unknown, there are no running sessions, the named session can't
            be found, some sessions couldn't be stopped, or no session name was given when one is required.
    """

    def stop_sessions(session_provider: ISessionProvider) -> Optional[SessionStopStatus]:
        """Stop sessions with one provider; return ``None`` if the provider was skipped as unreachable."""
        try:
            return session_provider.session_stop(
                project_name=project_name, session_name=session_name, stop_all=stop_all
            )
        except errors.RenkulabSessionGetUrlError as e:
            if provider:
                raise
            communication.warn(f"Didn't stop any renkulab sessions: {e}")
            return None
        except errors.DockerError as e:
            if provider:
                raise
            communication.warn(f"Didn't stop any docker sessions: {e}")
            return None

    session_detail = "all sessions" if stop_all else f"session {session_name}" if session_name else "session"
    project_name = get_renku_project_name()
    providers = [_safe_get_provider(provider)] if provider else get_supported_session_providers()

    statuses = []
    warning_messages = []
    with communication.busy(msg=f"Waiting for {session_detail} to stop..."):
        for session_provider in sorted(providers, key=lambda p: p.priority):
            try:
                status = stop_sessions(session_provider)
            except errors.RenkuException as e:
                warning_messages.append(f"Cannot stop sessions in provider '{session_provider.name}': {e}")
                continue

            if status is None:
                # NOTE: A skipped provider still counts as successful for the checks below (so it doesn't produce
                # "no running sessions"), but it did not stop the named session, so keep trying other providers
                # instead of breaking out early.
                statuses.append(SessionStopStatus.SUCCESSFUL)
                continue

            statuses.append(status)
            # NOTE: The given session name was stopped; don't continue
            if session_name and not stop_all and status == SessionStopStatus.SUCCESSFUL:
                break

    for message in warning_messages:
        communication.warn(message)

    if not statuses:
        return
    elif all(s == SessionStopStatus.NO_ACTIVE_SESSION for s in statuses):
        raise errors.ParameterError("There are no running sessions.")
    elif session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statuses):
        raise errors.ParameterError(f"Could not find '{session_name}' among the running sessions.")
    elif any(s == SessionStopStatus.FAILED for s in statuses):
        raise errors.ParameterError("Cannot stop some sessions")
    elif not session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statuses):
        raise errors.ParameterError("Session name is missing")
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def session_open(session_name: Optional[str], provider: Optional[str] = None, **kwargs):
    """Open an interactive session in the browser.

    Providers are tried in turn until one of them reports that it opened the session.

    Args:
        session_name(Optional[str]): Name of the session to open.
        provider(Optional[str]): Name of the session provider to use; all supported providers are tried if not set.
        **kwargs: Extra options forwarded to each provider's ``session_open``.

    Raises:
        errors.ParameterError: If no provider opened the session.
    """
    if provider:
        candidates = [_safe_get_provider(provider)]
    else:
        candidates = get_supported_session_providers()
    project_name = get_renku_project_name()

    opened = any(candidate.session_open(project_name, session_name, **kwargs) for candidate in candidates)
    if opened:
        return

    if not session_name:
        raise errors.ParameterError("Session name is missing")
    raise errors.ParameterError(f"Could not find '{session_name}' among the running sessions.")
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def ssh_setup(existing_key: Optional[Path] = None, force: bool = False):
    """Setup SSH keys for SSH connections to sessions.

    Args:
        existing_key(Path, optional): Existing private key file to use instead of generating new ones. Its public
            key is expected next to it, named like the private key plus a ``.pub`` suffix.
        force(bool): Whether to overwrite already configured keys without prompting.

    Raises:
        errors.SSHNotFoundError: If no ``ssh`` executable is found on ``PATH``.
        errors.KeyNotFoundError: If ``existing_key`` or its ``.pub`` counterpart doesn't exist.
    """
    if not shutil.which("ssh"):
        raise errors.SSHNotFoundError()

    system_config = SystemSSHConfig()

    include_string = f"Include {system_config.renku_ssh_root}/*.conf\n\n"

    if include_string not in system_config.ssh_config.read_text():
        with system_config.ssh_config.open(mode="r+") as f:
            content = f.read()
            # NOTE: We need to add 'Include' before any 'Host' entry, otherwise it is included as part of a host
            f.seek(0, 0)
            f.write(include_string + content)

    if not existing_key and not force and system_config.is_configured:
        communication.confirm(f"Keys already configured for host {system_config.renku_host}. Overwrite?", abort=True)

    def remove_key_file(path: Path) -> None:
        """Delete an old key file or symlink, including a dangling symlink."""
        # NOTE: ``exists()`` follows symlinks and is False for a dangling one, which would make ``os.symlink`` fail.
        if path.is_symlink() or path.exists():
            path.unlink()

    if existing_key:
        communication.info("Linking existing keys")
        existing_public_key = existing_key.parent / (existing_key.name + ".pub")

        if not existing_key.exists() or not existing_public_key.exists():
            raise errors.KeyNotFoundError(
                f"Couldn't find private key '{existing_key}' or public key '{existing_public_key}'."
            )

        remove_key_file(system_config.keyfile)
        remove_key_file(system_config.public_keyfile)

        # NOTE: Link to absolute paths; a relative symlink target is resolved against the link's own directory.
        os.symlink(os.path.abspath(existing_key), system_config.keyfile)
        os.symlink(os.path.abspath(existing_public_key), system_config.public_keyfile)
    else:
        communication.info("Generating keys")
        keys = generate_ssh_keys()

        # NOTE: Remove old files first. A leftover symlink from a previous ``existing_key`` setup would otherwise
        # be written through, overwriting the user's own key. Also, ``touch(mode=...)`` doesn't change the mode of
        # an existing file.
        remove_key_file(system_config.keyfile)
        remove_key_file(system_config.public_keyfile)

        system_config.keyfile.touch(mode=0o600)
        system_config.public_keyfile.touch(mode=0o644)
        with system_config.keyfile.open("wt") as f:
            f.write(keys.private_key)

        with system_config.public_keyfile.open("wt") as f:
            f.write(keys.public_key)

    communication.info("Writing SSH config")
    with system_config.jumphost_file.open(mode="wt") as f:
        # NOTE: The * at the end of the jumphost name hides it from VSCode
        content = textwrap.dedent(
            f"""
            Host jumphost-{system_config.renku_host}*
                HostName {system_config.renku_host}
                Port 2022
                User jovyan
            """
        )
        f.write(content)

    communication.warn(
        "This command does not add any public SSH keys to your project. "
        "Keys have to be added manually or by using the 'renku session start' command with the '--ssh' flag."
    )
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def session_pause(session_name: Optional[str], provider: Optional[str] = None, **kwargs):
    """Pause an interactive session.

    Args:
        session_name(Optional[str]): Name of the session.
        provider(Optional[str]): Name of the session provider to use; all providers that support hibernation are
            tried if not set.
        **kwargs: Accepted for interface compatibility; not forwarded to the providers.

    Raises:
        errors.ParameterError: If the provider is unknown or doesn't support pausing, there are no running
            sessions, the named session can't be found, or no session name was given when one is required.
    """

    def pause(session_provider: IHibernatingSessionProvider) -> SessionStopStatus:
        try:
            return session_provider.session_pause(project_name=project_name, session_name=session_name)
        except errors.RenkulabSessionGetUrlError:
            if provider:
                raise
            return SessionStopStatus.FAILED

    project_name = get_renku_project_name()

    if provider:
        # NOTE: ``_safe_get_provider`` raises ``ParameterError`` for unknown providers, so it never returns ``None``.
        session_provider = _safe_get_provider(provider)
        if not isinstance(session_provider, IHibernatingSessionProvider):
            raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing sessions")
        providers = [session_provider]
    else:
        providers = get_supported_hibernating_session_providers()

    session_message = f"session {session_name}" if session_name else "session"
    statuses = []
    warning_messages = []
    with communication.busy(msg=f"Waiting for {session_message} to pause..."):
        for session_provider in sorted(providers, key=lambda p: p.priority):
            try:
                status = pause(session_provider)  # type: ignore
            except errors.RenkuException as e:
                warning_messages.append(f"Cannot pause sessions in provider '{session_provider.name}': {e}")
            else:
                statuses.append(status)
                # NOTE: The given session name was paused; don't continue
                if session_name and status == SessionStopStatus.SUCCESSFUL:
                    break

    for message in warning_messages:
        communication.warn(message)

    if not statuses:
        return
    elif all(s == SessionStopStatus.NO_ACTIVE_SESSION for s in statuses):
        raise errors.ParameterError("There are no running sessions.")
    elif session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statuses):
        raise errors.ParameterError(f"Could not find '{session_name}' among the running sessions.")
    elif not session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statuses):
        raise errors.ParameterError("Session name is missing")
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def session_resume(session_name: Optional[str], provider: Optional[str] = None, **kwargs):
    """Resume a paused session.

    Providers are tried in turn until one of them reports that it resumed the session.

    Args:
        session_name(Optional[str]): Name of the session.
        provider(Optional[str]): Name of the session provider to use; all providers that support hibernation are
            tried if not set.
        **kwargs: Extra options forwarded to each provider's ``session_resume``.

    Raises:
        errors.ParameterError: If the provider is unknown or doesn't support hibernation, or no provider resumed
            the session.
    """
    project_name = get_renku_project_name()

    if provider:
        # NOTE: ``_safe_get_provider`` raises ``ParameterError`` for unknown providers, so it never returns ``None``.
        session_provider = _safe_get_provider(provider)
        if not isinstance(session_provider, IHibernatingSessionProvider):
            raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing/resuming sessions")
        providers = [session_provider]
    else:
        providers = get_supported_hibernating_session_providers()

    session_message = f"session {session_name}" if session_name else "session"
    with communication.busy(msg=f"Waiting for {session_message} to resume..."):
        for session_provider in providers:
            if session_provider.session_resume(project_name, session_name, **kwargs):  # type: ignore
                return

    if session_name:
        raise errors.ParameterError(f"Could not find '{session_name}' among the sessions.")
    else:
        raise errors.ParameterError("Session name is missing")