# -*- coding: utf-8 -*-
from __future__ import (absolute_import, division,
print_function, unicode_literals)
import jubatus
import jubatus.embedded
import uuid
from .base import BaseDataset, BaseService, GenericConfig, GenericSchema
from .compat import unicode_t
from .loader.array import ArrayLoader, ZipArrayLoader
from .loader.chain import MergeChainLoader
from .loader.sparse import SparseMatrixLoader
[docs]class Schema(GenericSchema):
"""Schema for Nearest Neighbor service."""
ID = 'i'
[docs] def __init__(self, mapping, fallback=None):
self._id_key = self._get_unique_mapping(mapping, fallback, self.ID,
'ID', True)
super(Schema, self).__init__(mapping, fallback)
[docs]class Dataset(BaseDataset):
"""Dataset for Nearest Neighbor service."""
@classmethod
def _predict(cls, row):
return Schema.predict(row, False)
@classmethod
def _from_loader(cls, data_loader, ids, static):
if ids is None:
loader = data_loader
schema = Schema({}, Schema.NUMBER)
else:
id_loader = ZipArrayLoader(_id=ids)
loader = MergeChainLoader(data_loader, id_loader)
schema = Schema({'_id': Schema.ID}, Schema.NUMBER)
return Dataset(loader, schema, static)
[docs] @classmethod
def from_data(cls, data, ids=None, feature_names=None, static=True):
"""Converts two arrays or a sparse matrix data and its associated
id array to Dataset.
Parameters
----------
data : array or scipy 2-D sparse matrix of shape
[n_samples, n_features]
ids : array of shape [n_samples], optional
feature_names : array of shape [n_features], optional
"""
if hasattr(data, 'todense'):
return cls.from_matrix(data, ids, feature_names, static)
else:
return cls.from_array(data, ids, feature_names, static)
[docs] @classmethod
def from_array(cls, data, ids=None, feature_names=None, static=True):
"""Converts two arrays (data and its associated targets) to Dataset.
Parameters
----------
data : array of shape [n_samples, n_features]
ids : array of shape [n_samples], optional
feature_names : array of shape [n_features], optional
"""
data_loader = ArrayLoader(data, feature_names)
return cls._from_loader(data_loader, ids, static)
[docs] @classmethod
def from_matrix(cls, data, ids=None, feature_names=None, static=True):
"""Converts a sparse matrix data and its associated target array
to Dataset.
Parameters
----------
data : scipy 2-D sparse matrix of shape [n_samples, n_features]
ids : array of shape [n_samples], optional
feature_names : array of shape [n_features], optional
"""
data_loader = SparseMatrixLoader(data, feature_names)
return cls._from_loader(data_loader, ids, static)
[docs] def get_ids(self):
"""
Returns labels of each record in the dataset.
"""
if not self._static:
raise RuntimeError('non-static datasets cannot fetch list of ids')
for (idx, (row_id, d)) in self:
yield row_id
[docs]class NearestNeighbor(BaseService):
"""Nearest Neighbor service."""
[docs] @classmethod
def name(cls):
return 'nearest_neighbor'
@classmethod
def _client_class(cls):
return jubatus.nearest_neighbor.client.NearestNeighbor
@classmethod
def _embedded_class(cls):
return jubatus.embedded.NearestNeighbor
[docs] def set_row(self, dataset):
"""Updates the row whose id is id with given row.
If the row with the same id already exists, the row is overwritten with
row (note that this behavior is different from that of recommender).
Otherwise, new row entry will be created.
If the server that manages the row and the server that received
this RPC request are same, this operation is reflected instantly.
If not, update operation is reflected after mix."""
cli = self._client()
for (idx, (row_id, d)) in dataset:
if row_id is None:
raise RuntimeError('dataset must have id.')
result = cli.set_row(row_id, d)
yield (idx, row_id, result)
[docs] def neighbor_row_from_id(self, dataset, size=10):
"""Returns size rows (at maximum) that have most similar datum
to id and their distance values."""
cli = self._client()
for (idx, (row_id, _)) in dataset:
if row_id is None:
raise RuntimeError('each data point must have its id.')
result = cli.neighbor_row_from_id(row_id, size)
yield (idx, row_id, result)
[docs] def neighbor_row_from_datum(self, dataset, size=10):
"""Returns size rows (at maximum) of which datum are most similar to
query and their distance values."""
cli = self._client()
for (idx, (row_id, d)) in dataset:
if row_id is None:
raise RuntimeError('each data point must have its id.')
result = cli.neighbor_row_from_datum(d, size)
yield (idx, row_id, result)
[docs] def similar_row_from_id(self, dataset, size=10):
"""Returns ret_num rows (at maximum) that have most similar datum to id
and their similarity values.
"""
cli = self._client()
for (idx, (row_id, _)) in dataset:
if row_id is None:
raise RuntimeError(
'Non ID-based datasets must use `similar_row_from_datum`')
result = cli.similar_row_from_id(row_id, size)
yield (idx, row_id, result)
[docs] def similar_row_from_datum(self, dataset, size=10):
"""Returns ret_num rows (at maximum) of which datum are most similar
to query and their similarity values.
"""
cli = self._client()
for (idx, (row_id, d)) in dataset:
result = cli.similar_row_from_datum(d, size)
yield (idx, row_id, result)
[docs] def get_all_rows(self):
"""Returns the list of all row IDs."""
cli = self._client()
return cli.get_all_rows()
[docs]class Config(GenericConfig):
"""Configuration to run Nearest Neighbor service."""
[docs] @classmethod
def methods(cls):
return ['lsh', 'minhash', 'euclid_lsh']
@classmethod
def _default_method(cls):
return 'lsh'
@classmethod
def _default_parameter(cls, method):
if method not in Config.methods():
raise RuntimeError('unknown method: {0}'.format(method))
return {
'threads': 1,
'hash_num': 128
}