Source code for afs.models

import json
import logging
import os
from io import BytesIO
import afs.utils as utils
import re
import base64
from uuid import uuid4
from afs.utils import (upload_file_to_blob, dowload_file_from_blob,
                        encrypt, decrypt)
from afs.get_env import AfsEnv


class models(AfsEnv):
    """Client for the AFS model repository service.

    Talks to the ``instances/<instance_id>/model_repositories`` REST
    endpoints and stores model files in a blob store under the key
    ``models/<instance_id>/<repo_id>/<model_id>``.
    """

    # Upload size limit in bytes (5 GiB); files of this size or larger are rejected.
    _MAX_UPLOAD_SIZE = 5 * 1024 ** 3

    def __init__(
        self, target_endpoint=None, instance_id=None, auth_code=None, token=None
    ):
        """Connect to the AFS models service.

        Connection settings can come from environment parameters or be passed
        explicitly when the object is created; all four arguments are
        forwarded to ``AfsEnv``.
        """
        super(models, self).__init__(target_endpoint, instance_id, auth_code, token)
        self.entity_uri = "model_repositories"
        self.sub_entity_uri = "models"
        # Id of the repository most recently resolved or created; reused by
        # methods that are called without a repository name.
        self.repo_id = None
        # Blob credentials start from the inherited settings and can be
        # replaced later with set_blob_credential().
        self._blob_endpoint = self.blob_endpoint
        self._blob_accessKey = self.blob_accessKey
        self._blob_secretKey = self.blob_secretKey

    def set_blob_credential(
        self,
        blob_endpoint,
        encode_blob_accessKey,
        encode_blob_secretKey,
        blob_record_id,
        bucket_name,
    ):
        """Set the blob credentials used to upload/download model files.

        :param str blob_endpoint: blob endpoint
        :param str encode_blob_accessKey: blob accessKey encoded with base64
        :param str encode_blob_secretKey: blob secretKey encoded with base64
        :param str blob_record_id: MD5 of instance_id + '_' + accessKey
        :param str bucket_name: blob bucket name
        :raises ValueError: if either key is not valid base64-encoded UTF-8.
        """
        try:
            access_key = str(base64.b64decode(encode_blob_accessKey), "utf-8")
            secret_key = str(base64.b64decode(encode_blob_secretKey), "utf-8")
        except (ValueError, TypeError) as err:
            # binascii.Error and UnicodeDecodeError are both ValueError subclasses;
            # TypeError covers non str/bytes input.
            raise ValueError(
                "encode_blob_accessKey, encode_blob_secretKey cannot be decoded."
            ) from err
        self._blob_endpoint = blob_endpoint
        self._blob_accessKey = access_key
        self._blob_secretKey = secret_key
        self.blob_record_id = blob_record_id
        self.bucket_name = bucket_name

    def get_model_repo_id(self, model_repository_name=None):
        """Look up a model repository id by name and cache it in ``self.repo_id``.

        :param str model_repository_name: repository name. When omitted, the
            previously cached ``self.repo_id`` (possibly None) is returned.
        :return: str model repository id, or None if no repository matched.
        """
        if not model_repository_name:
            return self.repo_id
        resp = self._get(params=dict(name=model_repository_name)).json()
        resources = resp["resources"]
        # The first entry returned by the service for this name is used.
        self.repo_id = resources[0]["uuid"] if resources else None
        return self.repo_id

    def get_model_id(self, model_name=None, model_repository_name=None, last_one=True):
        """Get a model id by model name.

        :param str model_name: model name. When omitted, the first model
            returned by the service for the repository is used.
        :param str model_repository_name: model repository the model lives in;
            falls back to the cached ``self.repo_id`` when omitted.
        :param bool last_one: kept for backward compatibility; it does not
            change the lookup.
        :return: str model id, or None if no model matched.
        :raises ValueError: if no repository can be determined or found.
        """
        if model_repository_name:
            self.get_model_repo_id(model_repository_name=model_repository_name)
            if not self.repo_id:
                raise ValueError(
                    "Model repository with name {} not found.".format(
                        model_repository_name
                    )
                )
        elif not self.repo_id:
            raise ValueError("Please enter model_repository_name.")

        extra_paths = [self.repo_id, self.sub_entity_uri]
        if model_name:
            resp = self._get(extra_paths=extra_paths, params=dict(name=model_name)).json()
            # The name filter may not be exact on the server side, so match here.
            for resource in resp["resources"]:
                if resource["name"] == model_name:
                    return resource["uuid"]
            return None
        # NOTE(review): assumes the service lists the most recent model first
        # (that is what "last one" relied on) - confirm against the API.
        resources = self._get(extra_paths=extra_paths).json()["resources"]
        return resources[0]["uuid"] if resources else None

    def download_model(
        self, save_path, model_repository_name=None, model_name=None, last_one=False
    ):
        """Download a model file from a model repository.

        :param str save_path: destination path in the local file system
        :param str model_repository_name: repository holding the model; falls
            back to the cached ``self.repo_id`` when omitted.
        :param str model_name: specific model to fetch; when None the first
            model returned by the service is used.
        :param bool last_one: forwarded to get_model_id (no effect there).
        :return: True on success.
        :raises ValueError: if no repository or no matching model is found.
        """
        if model_repository_name:
            self.get_model_repo_id(model_repository_name)
        if not self.repo_id:
            raise ValueError("There is no specific repo_id to download.")
        model_id = self.get_model_id(
            model_name=model_name,
            model_repository_name=model_repository_name,
            last_one=last_one,
        )
        if not model_id:
            # Without this check the blob key would end in ".../None".
            raise ValueError("Model not found.")
        key = "models/{}/{}/{}".format(self.instance_id, self.repo_id, model_id)
        dowload_file_from_blob(
            self._blob_endpoint,
            self._blob_accessKey,
            self._blob_secretKey,
            self.bucket_name,
            key,
            save_path,
        )
        return True

    def upload_model(
        self,
        model_path,
        accuracy=None,
        loss=None,
        tags=None,
        extra_evaluation=None,
        feature_importance=None,
        coefficient=None,
        model_repository_name=None,
        model_name=None,
        encrypt_key='',
    ):
        """Upload a model to a model repository. (Support v2 API)

        :param str model_path: (required) model file path
        :param float accuracy: (optional) model accuracy value, between 0-1
        :param float loss: (optional) model loss value
        :param dict tags: (optional) model tags; the dict is not modified
        :param dict extra_evaluation: (optional) other evaluation values
        :param list feature_importance: (optional) how important each feature
            is in the model
        :param list coefficient: (optional) direction of the relationship
            between each predictor variable and the response
        :param str model_repository_name: (optional) target repository; it is
            created if it does not exist. Falls back to the cached
            ``self.repo_id`` when omitted.
        :param str model_name: (optional) model name, or a server default
        :param str encrypt_key: (optional) when set, an encrypted copy of the
            model is uploaded; the file at model_path is left unchanged.
        :return: dict. information about the uploaded model.
        :raises ValueError: bad argument values, no repository name, bad model
            name, or incomplete blob credentials.
        :raises TypeError: accuracy, loss or model_path of the wrong type.
        :raises IOError: model_path does not exist.
        :raises Exception: the file reaches the 5 GiB upload limit.
        :raises RuntimeError: the service rejected the file_info update.
        """
        tags = {} if tags is None else tags
        extra_evaluation = {} if extra_evaluation is None else extra_evaluation
        if not isinstance(tags, dict) or not isinstance(extra_evaluation, dict):
            raise ValueError(
                "Type error, accuracy and loss are float, and tags and extra_evaluation are dict."
            )
        # Copy so neither the caller's dict nor a shared object gets the
        # apm_node / is_encrypted entries added below.
        tags = dict(tags)

        evaluation_result = {}
        if accuracy is not None:
            if not isinstance(accuracy, (float, int)):
                raise TypeError("Type error, accuracy is float.")
            if accuracy > 1.0 or accuracy < 0:
                raise ValueError("Accuracy value should be between 0-1")
            evaluation_result["accuracy"] = accuracy
        if loss is not None:
            if not isinstance(loss, (float, int)):
                raise TypeError("Type error, loss is float.")
            evaluation_result["loss"] = loss

        if not isinstance(model_path, str):
            raise TypeError("Type error, model_path must be a string")
        if not os.path.isfile(model_path):
            raise IOError("File not found, model path is not exist.")
        # Validate the name before any remote side effect (repo creation, upload).
        if model_name:
            self._naming_rule(model_name)

        # An explicit repository name always wins over the cached repo_id, so a
        # previous call cannot redirect this upload to another repository.
        if model_repository_name:
            self.get_model_repo_id(model_repository_name)
            if not self.repo_id:
                self.create_model_repo(model_repository_name)
        elif not self.repo_id:
            raise ValueError("Please enter model_repository_name")

        apm_node = self._apm_node_from_env()
        if apm_node:
            tags["apm_node"] = apm_node
        # Record whether the uploaded file is encrypted.
        tags["is_encrypted"] = bool(encrypt_key)

        evaluation_result.update(extra_evaluation)
        payload = dict(
            tags=json.dumps(tags),
            evaluation_result=json.dumps(evaluation_result),
        )
        if feature_importance:
            payload["feature_importance"] = json.dumps(feature_importance)
        if coefficient:
            payload["coefficient"] = json.dumps(coefficient)
        # Numeric comparison: a string compare would rank "3.1.10" below "3.1.3".
        if self._version_tuple(self.afs_version) >= (3, 1, 3):
            payload["dataset_id"] = os.getenv("dataset_id")
            payload["afs_target"] = os.getenv("afs_target")
        if model_name:
            payload["name"] = model_name

        if not encrypt_key:
            return self._upload_model_file(model_path, payload)
        encrypted_path = self._write_encrypted_copy(model_path, encrypt_key)
        try:
            return self._upload_model_file(encrypted_path, payload)
        finally:
            os.remove(encrypted_path)

    @staticmethod
    def _version_tuple(version):
        """Turn a version such as '3.1.10' into (3, 1, 10) for numeric comparison."""
        return tuple(int(part) for part in re.findall(r"\d+", str(version)))

    @staticmethod
    def _apm_node_from_env():
        """Return the first APM machine id found in PAI_DATA_DIR, or None.

        PAI_DATA_DIR is read as base64-encoded JSON; an id is taken only when
        "type" is "apm-firehose" and data["machineIdList"] is non-empty.
        Parsing is best effort: problems are printed and ignored.
        """
        raw = os.getenv("PAI_DATA_DIR", None)
        if not raw:
            return None
        shown = raw
        try:
            shown = str(base64.b64decode(raw), "utf-8")
            firehose = json.loads(shown)
            if "data" in firehose and "type" in firehose:
                fh_data = firehose["data"]
                if "machineIdList" in fh_data and firehose["type"] == "apm-firehose":
                    machine_ids = fh_data.get("machineIdList") or [None]
                    if machine_ids[0]:
                        return str(machine_ids[0])
        except Exception as e:
            # Deliberately best effort: a malformed value must not block uploads.
            print(
                "PAI_DATA_DIR value is not valid json format for apm_node. Exception: {}, Value: {}".format(
                    e, shown
                )
            )
        return None

    @staticmethod
    def _write_encrypted_copy(model_path, encrypt_key):
        """Write an encrypted copy of model_path to a new temp file and return its path.

        The copy is placed next to the original file; the caller removes it.
        """
        import tempfile

        with open(model_path, "rb") as src:
            encrypted = encrypt(src.read(), encrypt_key)
        fd, tmp_path = tempfile.mkstemp(
            prefix=".afs_encrypted_",
            dir=os.path.dirname(os.path.abspath(model_path)),
        )
        try:
            with os.fdopen(fd, "wb") as dst:
                dst.write(encrypted)
        except Exception:
            os.remove(tmp_path)
            raise
        return tmp_path

    def _upload_model_file(self, file_path, payload):
        """Create model metadata, upload file_path to blob storage, then record its size.

        :param str file_path: file to upload
        :param dict payload: form fields for the model metadata
        :return: dict returned by the file_info update.
        """
        if os.path.getsize(file_path) >= self._MAX_UPLOAD_SIZE:
            raise Exception("The size of the file has exceeded the upper limit of 5G")

        bucket_name = getattr(self, "bucket_name", None)
        missing = [
            field
            for field, value in (
                ("blob_endpoint", self._blob_endpoint),
                ("blob_accessKey", self._blob_accessKey),
                ("blob_secretKey", self._blob_secretKey),
                ("bucket_name", bucket_name),
            )
            if not value
        ]
        if missing:
            # Only field names are reported so credentials never end up in logs.
            raise ValueError(
                "Blob information is not enough to put object to blob, missing: {}".format(
                    ", ".join(missing)
                )
            )

        resp = self._create(
            data=payload,
            extra_paths=[self.repo_id, self.sub_entity_uri],
            form="data",
        )
        model_id = resp.json()["uuid"]
        key = "models/{}/{}/{}".format(self.instance_id, self.repo_id, model_id)
        try:
            object_size = upload_file_to_blob(
                self._blob_endpoint,
                self._blob_accessKey,
                self._blob_secretKey,
                bucket_name,
                key,
                file_path,
            )
        except Exception:
            # Roll back the metadata so no model entry points at a missing blob.
            self._del(extra_paths=[self.repo_id, self.sub_entity_uri, model_id])
            raise

        resp = self._put(
            extra_paths=[self.repo_id, self.sub_entity_uri, model_id, "file_info"],
            data={"size": object_size, "blob_record_id": self.blob_record_id},
        )
        if resp.status_code // 100 == 2:
            return resp.json()
        raise RuntimeError(resp.text)

    def create_model_repo(self, model_repository_name):
        """Create a new model repository and cache its id. (Support v2 API)

        :param str model_repository_name: name of the model repository.
        :return: str uuid of the new repository.
        :raises TypeError: if the name is not a string.
        :raises ValueError: if the name breaks the naming rule.
        """
        if not isinstance(model_repository_name, str):
            raise TypeError("Repo name must be string")
        self._naming_rule(model_repository_name)
        resp = self._create(dict(name=model_repository_name))
        self.repo_id = resp.json()["uuid"]
        return self.repo_id

    def get_latest_model_info(self, model_repository_name=None):
        """Get info (created_at, tags, evaluation_result, ...) of the latest model. (Support v2 API)

        :param str model_repository_name: (optional) name of the model
            repository; falls back to the cached ``self.repo_id``.
        :return: dict. info of the first model returned for the repository.
        :raises ValueError: if no model is found.
        """
        model_id = self.get_model_id(
            model_repository_name=model_repository_name, last_one=True
        )
        if not model_id:
            raise ValueError("Model not found.")
        resp = self._get(extra_paths=[self.repo_id, self.sub_entity_uri, model_id])
        return resp.json()

    def get_model_info(self, model_name, model_repository_name=None):
        """Get model info, including created_at, tags, evaluation_result. (V2 API)

        :param str model_name: model name
        :param str model_repository_name: name of the model repository; falls
            back to the cached ``self.repo_id``.
        :return: dict model info
        :raises ValueError: if the repository or model is not found.
        """
        if not self.get_model_repo_id(model_repository_name=model_repository_name):
            raise ValueError("Model_repository not found.")
        model_id = self.get_model_id(
            model_name=model_name, model_repository_name=model_repository_name
        )
        if not model_id:
            raise ValueError("Model not found.")
        resp = self._get(extra_paths=[self.repo_id, self.sub_entity_uri, model_id])
        return resp.json()

    def delete_model_repository(self, model_repository_name):
        """Delete a model repository.

        :param str model_repository_name: model repository name.
        :return: bool, True when the service answered with a 2xx status.
        :raises ValueError: if the repository is not found.
        """
        if not self.get_model_repo_id(model_repository_name):
            raise ValueError("Model_repository not found.")
        resp = self._del(extra_paths=[self.repo_id])
        if resp.status_code // 100 == 2:
            # Forget the deleted repository so later calls don't target it.
            self.repo_id = None
            return True
        return False

    def delete_model(self, model_name, model_repository_name=None):
        """Delete a model.

        :param str model_name: model name.
        :param str model_repository_name: model repository name; falls back to
            the cached ``self.repo_id``.
        :return: bool, True when the service answered with a 2xx status.
        :raises ValueError: if the repository or model is not found.
        """
        if not self.get_model_repo_id(model_repository_name=model_repository_name):
            raise ValueError("Model_repository not found.")
        model_id = self.get_model_id(
            model_name, model_repository_name=model_repository_name, last_one=False
        )
        if not model_id:
            raise ValueError("Model not found.")
        resp = self._del(extra_paths=[self.repo_id, self.sub_entity_uri, model_id])
        return resp.status_code // 100 == 2

    def decrypt_model(self, model, decrypt_key):
        """Decrypt model data.

        :param object model: the encrypted model data
        :param str decrypt_key: key used to decrypt the model
        :return: object returned by afs.utils.decrypt
        """
        return decrypt(model, decrypt_key)

    # NOTE(review): every request below disables TLS certificate verification
    # (verify=False); consider making this configurable.

    def _url(self, extra_paths=None):
        """Build the entity URL, optionally followed by extra path segments."""
        return utils.urljoin(
            self.target_endpoint,
            "instances",
            self.instance_id,
            self.entity_uri,
            extra_paths=[] if extra_paths is None else extra_paths,
        )

    def _create(self, data, files=None, extra_paths=None, form="json"):
        """POST data to the entity URL.

        :param data: request body
        :param files: (optional) files for a multipart upload
        :param list extra_paths: (optional) extra URL path segments
        :param str form: "json" to send data as JSON, "data" to send form fields
        :raises ValueError: for any other form value.
        """
        if form not in ("json", "data"):
            raise ValueError("form must be 'json' or 'data', got {!r}".format(form))
        # form doubles as the requests keyword: json=data or data=data.
        body = {form: data}
        if files:
            body["files"] = files
        return utils._check_response(
            self.session.post(
                self._url(extra_paths),
                params=dict(auth_code=self.auth_code),
                verify=False,
                **body
            )
        )

    def _get(self, params=None, extra_paths=None):
        """GET the entity URL with auth_code plus any extra query params."""
        get_params = dict(auth_code=self.auth_code)
        get_params.update(params or {})
        return utils._check_response(
            self.session.get(self._url(extra_paths), params=get_params, verify=False)
        )

    def _del(self, params=None, extra_paths=None):
        """DELETE the entity URL with auth_code plus any extra query params."""
        del_params = dict(auth_code=self.auth_code)
        del_params.update(params or {})
        return utils._check_response(
            self.session.delete(self._url(extra_paths), params=del_params, verify=False)
        )

    def _naming_rule(self, name):
        """Validate a repository or model name.

        Names must be 1-72 characters of ASCII letters, digits, '-', '_' or '.'.

        :return: True when the name is valid.
        :raises ValueError: otherwise.
        """
        limit = 72
        if len(name) > limit or len(name) < 1:
            raise ValueError("Name length is upper limit 1-{} char".format(limit))
        if re.fullmatch(r"[A-Za-z0-9._-]+", name) is None:
            raise ValueError("Naming rule is only a-z, A-Z, 0-9, -, _ and . allowed.")
        return True

    def _put(self, data, extra_paths=None):
        """PUT data as JSON to the entity URL."""
        return utils._check_response(
            self.session.put(
                self._url(extra_paths),
                params=dict(auth_code=self.auth_code),
                json=data,
                verify=False,
            )
        )