import os
import logging
from threading import Thread, Event
from datetime import datetime
from collections import deque
from google.protobuf import json_format
import grpc
from . import proto
try:
    from pyats.log.utils import banner
    from pyats.connections import BaseConnection
    from pyats.utils.secret_strings import to_plaintext
    from unicon.sshutils import sshtunnel
except ImportError:
    # Standalone without pyats install: minimal stand-ins for the
    # pyATS / unicon names this module uses.

    # SSH tunnelling needs unicon. Binding the name (rather than leaving it
    # undefined) makes an 'sshtunnel' setting fail with the AttributeError
    # that Gnmi.connect reports, instead of a NameError.
    sshtunnel = None

    class BaseConnection:
        """Minimal stand-in for ``pyats.connections.BaseConnection``."""

        class dev:
            """Bare device record exposing only the ``os`` attribute."""

            def __init__(self, dev_os):
                self.os = dev_os

        def __init__(self, device_os, **kwargs):
            # Protocol defaults to gNMI; any keyword argument (including
            # 'protocol') overrides or extends the connection settings.
            self.connection_info = {'protocol': 'gnmi'}
            self._device = self.dev(device_os)
            self.connection_info.update(kwargs)

    def banner(string):
        """Return *string* unchanged (pyATS banner formatting unavailable)."""
        return string

    def to_plaintext(string):
        """Return *string* unchanged (pyATS secret strings unavailable)."""
        return string
# Module-level logger; Gnmi.connect attaches a GnmiLogHandler to it.
log = logging.getLogger(name=__name__)
class gNMIException(IOError):
    """Error raised by the gNMI connector, e.g. when connect gets no capabilities."""
class GnmiNotification(Thread):
    """Thread listening for event notifications from the device.

    Args:
        device: Device the subscription belongs to (kept as ``self.device``).
        response: Iterable of subscribe responses consumed by ``run``.
        **request: Subscription settings; see the ``request`` setter.
    """

    def __init__(self, device, response, **request):
        Thread.__init__(self)
        self.device = device
        self._stop_event = Event()
        self.log = logging.getLogger(__name__)
        self.log.setLevel(logging.DEBUG)
        self.request = request
        self.responses = response

    @property
    def request(self):
        """dict: Request settings last assigned to this notification."""
        return self._request

    @request.setter
    def request(self, request=None):
        """Store *request* and unpack its settings into attributes.

        Keys read: ``returns``, ``verifier``, ``decode``, ``namespace`` and
        ``format`` (a dict with optional ``sub_mode``, ``encoding``,
        ``sample_interval`` and ``stream_max``). A missing or empty
        ``format`` is treated as ``{}`` so the defaults apply.
        """
        if request is None:
            request = {}
        fmt = request.get('format') or {}
        self.returns = request.get('returns')
        self.response_verify = request.get('verifier')
        self.decode_response = request.get('decode')
        self.namespace = request.get('namespace')
        self.sub_mode = fmt.get('sub_mode', 'SAMPLE')
        self.encoding = fmt.get('encoding', 'PROTO')
        self.sample_interval = fmt.get('sample_interval', 10)
        self.stream_max = fmt.get('stream_max', 0)
        # Runtime state, reset whenever a new request is assigned.
        self.time_delta = 0
        self.result = None
        self.event_triggered = False
        self._request = request

    def process_opfields(self, response):
        """Decode response and verify result.

        Decoder callback returns desired format of response.
        Verify callback returns verification of expected results.

        Args:
            response (proto.gnmi_pb2.SubscribeResponse): Response whose
                ``update`` notification holds changes since last timestamp.
        """
        subscribe_resp = json_format.MessageToDict(response)
        notification = subscribe_resp.get('update', {})
        # A notification may carry only deletes and no 'update' list.
        for update in notification.get('update', []):
            resp = self.decode_response(update, self.namespace)
            if not self.event_triggered:
                continue
            if not resp:
                self.log.error('No values in subscribe response')
            elif not self.returns:
                self.log.error('No notification values to check')
                self.result = False
                self.stop()
            else:
                self.result = self.response_verify(
                    resp, self.returns.copy())

    def run(self):
        """Check for inbound notifications.

        Consumes ``self.responses`` until they are exhausted, the thread is
        stopped, or more than ``stream_max`` seconds (when non-zero) have
        elapsed. Any exception ends the loop and its description is stored
        in ``self.result``.
        """
        t1 = datetime.now()
        self.log.info('\nSubscribe notification active\n{0}'.format(29 * '='))
        try:
            for response in self.responses:
                self.log.info(response)
                if self.stopped():
                    self.time_delta = self.stream_max
                    self.log.info("Terminating notification thread")
                    break
                if self.stream_max:
                    # total_seconds() includes whole days; timedelta.seconds
                    # wraps back to 0 every 24 hours.
                    elapsed = int((datetime.now() - t1).total_seconds())
                    self.time_delta = elapsed
                    if elapsed > self.stream_max:
                        self.stop()
                        break
                if response.HasField('sync_response'):
                    self.log.info('Subscribe syncing response')
                if response.HasField('update'):
                    self.log.info('\nSubscribe response:\n{0}\n{1}'.format(
                        19 * '=', str(response)))
                    self.process_opfields(response)
        except Exception as exc:
            # Best effort: record the failure for the caller instead of
            # letting the thread die with an unhandled exception.
            msg = ''
            if hasattr(exc, 'details'):
                msg += f'details: {exc.details()}'
            if hasattr(exc, 'debug_error_string'):
                msg += exc.debug_error_string()
            if not msg:
                msg = str(exc)
            self.result = msg

    def stop(self):
        """Ask the notification loop to terminate."""
        self.log.info("Stopping notification stream")
        self._stop_event.set()

    def stopped(self):
        """Return True once ``stop`` has been called."""
        return self._stop_event.is_set()
class CiscoAuthPlugin(grpc.AuthMetadataPlugin):
    """Attach username/password metadata to every outgoing gRPC call.

    Args:
        username (str): Value sent as the ``username`` metadata entry.
        password (str): Value sent as the ``password`` metadata entry.
    """

    def __init__(self, username, password):
        super().__init__()
        self.username, self.password = username, password

    def __call__(self, context, callback):
        """Hand the credential metadata to gRPC, reporting no error."""
        credentials = [
            ("username", self.username),
            ("password", self.password),
        ]
        callback(credentials, None)
class GnmiLogHandler(logging.Handler):
    """Logging handler that appends log messages to a gNMI session.

    Each handled record is appended to ``gnmi_session.results``. Records
    handled before a session is attached are dropped.
    """

    # No session until one is assigned, so emit() must tolerate None.
    _gnmi_session = None

    @property
    def gnmi_session(self):
        """Session object whose ``results`` collection receives messages."""
        return self._gnmi_session

    @gnmi_session.setter
    def gnmi_session(self, session):
        self._gnmi_session = session

    def emit(self, record):
        """Append the record's message to the attached session's results.

        Messages logged without arguments (including non-string objects)
        are stored as-is. Records logged with %-style arguments are stored
        fully formatted. Failures go to ``handleError`` instead of
        propagating to the logging call site.
        """
        session = self.gnmi_session
        if session is None:
            return
        try:
            if record.args:
                message = record.getMessage()
            else:
                message = record.msg
            session.results.append(message)
        except Exception:
            self.handleError(record)
class Gnmi(BaseConnection):
    """Session handling for gNMI connections.

    Can be used with pyATS the same way yang.connector.Netconf is used, or
    as a standalone module.

    Methods:
        capabilities(): gNMI Capabilities.
        set(SetRequest): gNMI Set.
        get(GetRequest): gNMI Get.
        subscribe(iterator of SubscribeRequest): gNMI Subscribe.

    pyATS Examples:

    >>> from pyats.topology import loader
    >>> from yang.connector.gnmi import Gnmi
    >>> testbed=loader.load('testbed.static.yaml')
    >>> device=testbed.devices['uut']
    >>> device.connect(alias='gnmi', via='yang2')
    >>> #####################
    >>> # Capabilities      #
    >>> #####################
    >>> resp=device.gnmi.capabilities()
    >>> resp.gNMI_version
    '0.7.0'
    >>> #####################
    >>> # Get example       #
    >>> #####################
    >>> from yang.connector import proto
    >>> request = proto.gnmi_pb2.GetRequest()
    >>> request.type = proto.gnmi_pb2.GetRequest.DataType.Value('ALL')
    >>> request.encoding = proto.gnmi_pb2.Encoding.Value('JSON_IETF')
    >>> path = proto.gnmi_pb2.Path()
    >>> path1, path2, path3, path4 = (
    ...     proto.gnmi_pb2.PathElem(),
    ...     proto.gnmi_pb2.PathElem(),
    ...     proto.gnmi_pb2.PathElem(),
    ...     proto.gnmi_pb2.PathElem()
    ... )
    >>> path1.name, path2.name, path3.name, path4.name = (
    ...     'syslog',
    ...     'messages',
    ...     'message',
    ...     'node-name'
    ... )
    >>> path.elem.extend([path1, path2, path3, path4])
    >>> request.path.append(path)
    >>> resp = device.gnmi.get(request)
    >>> print(resp)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = kwargs.get('device')
        self.dev_args = self.connection_info
        if self.dev_args.get('protocol', '') != 'gnmi':
            msg = 'Invalid protocol {0}'.format(
                self.dev_args.get('protocol', ''))
            raise TypeError(msg)
        self.active_notifications = {}
        self.channel = None
        # gNMI stub. It stays None until connect() succeeds, so `connected`
        # (and therefore `with Gnmi(...)`) works on a fresh session.
        self.service = None
        # Messages collected by this session's GnmiLogHandler.
        self.results = deque()
        # Per-call credential metadata; only set for insecure channels.
        self.metadata = None
        # GnmiLogHandler created by the first connect() and reused after.
        self._log_handler = None

    @property
    def connected(self):
        """Return True if session is connected."""
        return self.service is not None

    @property
    def gnmi(self):
        """Helper method to keep backwards compatibility.

        Returns:
            Gnmi: self
        """
        return self

    def connect(self):
        """Connect to device using gNMI and get capabilities.

        A secure channel is built when any of ``root_certificate``,
        ``certificate_chain`` or ``private_key`` is configured; credentials
        are then sent through CiscoAuthPlugin. Otherwise an insecure channel
        is used and credentials are sent as call metadata.

        Raises:
            KeyError: No usable username/password could be found.
            AttributeError: An SSH tunnel was requested but could not be added.
            gNMIException: No gNMI capabilities returned by device.
        """
        dev_args = self.dev_args
        self._setup_logging(dev_args)
        username, password = self._get_credentials(dev_args)
        host, port = self._get_target(dev_args)
        target = '{0}:{1}'.format(host, port)
        options = [('grpc.max_receive_message_length', 1000000000)]

        # Certificate settings may be file paths or inline PEM data.
        root = self._load_certificate(dev_args.get('root_certificate'))
        chain = self._load_certificate(dev_args.get('certificate_chain'))
        private_key = self._load_certificate(dev_args.get('private_key'))

        if any((root, chain, private_key)):
            override_name = dev_args.get('ssl_name_override', '')
            if override_name:
                self.log.info('Host override secure channel')
                options.append(
                    ('grpc.ssl_target_name_override', override_name))
            self.log.info("Connecting secure channel")
            channel_ssl_creds = grpc.ssl_channel_credentials(
                root, private_key, chain)
            ssl_metadata = grpc.metadata_call_credentials(
                CiscoAuthPlugin(username, password))
            channel_creds = grpc.composite_channel_credentials(
                channel_ssl_creds, ssl_metadata)
            self.channel = grpc.secure_channel(target, channel_creds, options)
            # Credentials travel via the call credentials above; drop any
            # metadata left over from an earlier insecure connect().
            self.metadata = None
        else:
            # Same options as the secure path, so the raised receive-size
            # limit applies to insecure channels as well.
            self.channel = grpc.insecure_channel(target, options=options)
            self.metadata = [
                ("username", username),
                ("password", password),
            ]
            self.log.info("Connecting insecure channel")

        self.service = proto.gnmi_pb2_grpc.gNMIStub(self.channel)
        try:
            resp = self.capabilities()
        except Exception:
            # Do not leave a half-open channel behind a failed connect.
            self.disconnect()
            raise
        if resp:
            log.info('\ngNMI version: {0} supported encodings: {1}\n\n'.format(
                resp.gNMI_version,
                [proto.gnmi_pb2.Encoding.Name(i)
                 for i in resp.supported_encodings]))
            log.info(banner('gNMI CONNECTED'))
        else:
            log.info(banner('gNMI Capabilities not returned'))
            self.disconnect()
            raise gNMIException('Connection not successful')

    def _setup_logging(self, dev_args):
        """Select the session logger and attach this session's log handler."""
        if dev_args.get('custom_log', ''):
            self.log = dev_args.get('custom_log')
        else:
            self.log = log
        self.log.setLevel(logging.INFO)
        # Create the handler once per session. Logger.addHandler ignores a
        # handler it already holds, so reconnecting does not stack copies.
        if self._log_handler is None:
            handler = GnmiLogHandler()
            handler.gnmi_session = self
            handler.setLevel(logging.INFO)
            self._log_handler = handler
        log.addHandler(self._log_handler)
        self.log.addHandler(self._log_handler)

    @staticmethod
    def _get_credentials(dev_args):
        """Return (username, plaintext password) from args or testbed credentials."""
        username = dev_args.get('username', '')
        password = dev_args.get('password', '')
        if not username or not password:
            creds = dev_args.get('credentials', '')
            if not creds:
                raise KeyError('No credentials found for testbed')
            # Prefer the 'gnmi' credential set, else the first one defined.
            if 'gnmi' in creds:
                cred_name = 'gnmi'
            else:
                cred_name = next(iter(creds))
                log.info('Credentials used from {0}'.format(cred_name))
            gnmi_uname_pwd = creds.get(cred_name)
            if not gnmi_uname_pwd:
                raise KeyError('No credentials found for gNMI')
            username = gnmi_uname_pwd.get('username', '')
            password = gnmi_uname_pwd.get('password', '')
            if not username or not password:
                raise KeyError('No credentials found for gNMI testbed')
        return username, to_plaintext(password)

    def _get_target(self, dev_args):
        """Return (host, port), going through an SSH tunnel when one is added."""
        if 'sshtunnel' in dev_args:
            try:
                tunnel_port = sshtunnel.auto_tunnel_add(self.device, self.via)
                if tunnel_port:
                    host = self.device.connections[self.via] \
                        .sshtunnel.tunnel_ip
                    return host, tunnel_port
            except AttributeError as err:
                raise AttributeError(
                    "Cannot add ssh tunnel. Connection %s may not have "
                    "ip/host or port.\n%s"
                    % (getattr(self, 'via', None), err)) from err
            # No tunnel port returned: connect directly instead of
            # failing later on unbound host/port.
            self.log.info('SSH tunnel not added, connecting directly')
        host = dev_args.get('host') or dev_args.get('ip')
        port = str(dev_args.get('port'))
        return host, port

    @staticmethod
    def _load_certificate(value):
        """Return PEM bytes from a file path or inline data, or None if unset."""
        if not value:
            return None
        if os.path.isfile(value):
            with open(value, 'rb') as cert_file:
                return cert_file.read()
        if isinstance(value, str):
            # grpc.ssl_channel_credentials requires bytes, not str.
            return value.encode()
        return value

    def set(self, request):
        """Gnmi SET method.

        Args:
            request (proto.gnmi_pb2.SetRequest): gNMI SetRequest object
        Returns:
            proto.gnmi_pb2.SetResponse: gNMI SetResponse object
        """
        return self.service.Set(request, metadata=self.metadata)

    def get(self, request):
        """Gnmi GET method.

        Args:
            request (proto.gnmi_pb2.GetRequest): gNMI GetRequest object
        Returns:
            proto.gnmi_pb2.GetResponse: gNMI GetResponse object
        """
        return self.service.Get(request, metadata=self.metadata)

    def execute(self, cmd):
        """Helper method for backwards compatibility; same as ``get``.

        Args:
            cmd (proto.gnmi_pb2.GetRequest): gNMI GetRequest object
        Returns:
            proto.gnmi_pb2.GetResponse: gNMI GetResponse object
        """
        return self.get(cmd)

    def capabilities(self):
        """Gnmi Capabilities method.

        Returns:
            proto.gnmi_pb2.CapabilityResponse: gNMI Capabilities object
        """
        request = proto.gnmi_pb2.CapabilityRequest()
        return self.service.Capabilities(request, metadata=self.metadata)

    def subscribe(self, request_iter):
        """Gnmi Subscribe method.

        Args:
            request_iter: Iterator of proto.gnmi_pb2.SubscribeRequest objects.
        Returns:
            Iterator of proto.gnmi_pb2.SubscribeResponse objects.
        """
        return self.service.Subscribe(request_iter, metadata=self.metadata)

    def disconnect(self):
        """Close the gRPC channel; safe to call more than once."""
        if self.channel is not None:
            self.channel.close()
        self.channel = None
        self.service = None

    def __enter__(self):
        """Establish a session using a Context Manager."""
        if not self.connected:
            self.connect()
        return self

    def __exit__(self, *args):
        """Gracefully close connection on Context Manager exit."""
        self.disconnect()