# -*- coding: utf-8 -*-
from __future__ import (absolute_import, division,
print_function, unicode_literals)
import jubatus
import jubatus.embedded
import uuid
from .base import BaseDataset, BaseService, GenericConfig, GenericSchema
from .compat import unicode_t
from .loader.array import ArrayLoader, ZipArrayLoader
from .loader.chain import MergeChainLoader
from .loader.sparse import SparseMatrixLoader
class Schema(GenericSchema):
    """Schema for Nearest Neighbor service.

    Besides the key types inherited from ``GenericSchema``, this schema
    defines ``ID`` to mark the column that carries each row's id.
    """

    # Key type marking the row-id column.
    ID = 'i'

    def __init__(self, mapping, fallback=None):
        """Create a Nearest Neighbor schema.

        Parameters
        ----------
        mapping : dict
            Maps column keys to key types, e.g. ``{'_id': Schema.ID}``.
        fallback : optional
            Key type forwarded to ``GenericSchema``; presumably applied to
            columns that ``mapping`` does not list (confirm in base.py).
        """
        # NOTE(review): ``_get_unique_mapping`` is defined in GenericSchema;
        # it presumably resolves the single column mapped to ``ID`` (the
        # trailing ``True`` looks like an "optional" flag -- confirm).
        id_column = self._get_unique_mapping(
            mapping, fallback, self.ID, 'ID', True)
        self._id_key = id_column
        super(Schema, self).__init__(mapping, fallback)
class Dataset(BaseDataset):
    """Dataset for Nearest Neighbor service.

    Each record may carry a row id (the column mapped to ``Schema.ID``);
    the remaining columns use the ``Schema.NUMBER`` fallback.
    """

    @classmethod
    def _predict(cls, row):
        # Delegates schema inference for a raw row to ``Schema.predict``.
        # NOTE(review): the meaning of the ``False`` flag is defined in
        # GenericSchema -- confirm in base.py.
        return Schema.predict(row, False)

    @classmethod
    def _from_loader(cls, data_loader, ids, static):
        """Build a dataset from a feature loader and an optional id array.

        When ``ids`` is given, it is zipped into an ``_id`` column and merged
        with the feature columns; ``_id`` is mapped to ``Schema.ID`` and all
        other columns fall back to ``Schema.NUMBER``.  Without ``ids`` every
        column falls back to ``Schema.NUMBER``.
        """
        if ids is None:
            loader = data_loader
            schema = Schema({}, Schema.NUMBER)
        else:
            id_loader = ZipArrayLoader(_id=ids)
            loader = MergeChainLoader(data_loader, id_loader)
            schema = Schema({'_id': Schema.ID}, Schema.NUMBER)
        # Instantiate ``cls`` (not a hard-coded ``Dataset``) so subclasses
        # using from_data / from_array / from_matrix get their own type.
        return cls(loader, schema, static)

    @classmethod
    def from_data(cls, data, ids=None, feature_names=None, static=True):
        """Converts an array or a sparse matrix and its associated id array
        to Dataset.

        Dispatches to ``from_matrix`` when ``data`` has a ``todense`` method
        (scipy sparse matrices), otherwise to ``from_array``.

        Parameters
        ----------
        data : array or scipy 2-D sparse matrix of shape
               [n_samples, n_features]
        ids : array of shape [n_samples], optional
            Row id of each sample.
        feature_names : array of shape [n_features], optional
        static : bool, optional
            Passed through to the Dataset constructor; ``get_ids`` requires
            a static dataset.
        """
        if hasattr(data, 'todense'):
            return cls.from_matrix(data, ids, feature_names, static)
        return cls.from_array(data, ids, feature_names, static)

    @classmethod
    def from_array(cls, data, ids=None, feature_names=None, static=True):
        """Converts a data array and its associated id array to Dataset.

        Parameters
        ----------
        data : array of shape [n_samples, n_features]
        ids : array of shape [n_samples], optional
            Row id of each sample.
        feature_names : array of shape [n_features], optional
        static : bool, optional
            Passed through to the Dataset constructor.
        """
        data_loader = ArrayLoader(data, feature_names)
        return cls._from_loader(data_loader, ids, static)

    @classmethod
    def from_matrix(cls, data, ids=None, feature_names=None, static=True):
        """Converts a sparse matrix and its associated id array to Dataset.

        Parameters
        ----------
        data : scipy 2-D sparse matrix of shape [n_samples, n_features]
        ids : array of shape [n_samples], optional
            Row id of each sample.
        feature_names : array of shape [n_features], optional
        static : bool, optional
            Passed through to the Dataset constructor.
        """
        data_loader = SparseMatrixLoader(data, feature_names)
        return cls._from_loader(data_loader, ids, static)

    def get_ids(self):
        """Returns an iterator over the row id of each record in the dataset.

        Presumably yields ``None`` for records of a dataset built without
        ``ids`` -- confirm against BaseDataset iteration.

        Raises
        ------
        RuntimeError
            If the dataset is not static.  Raised when ``get_ids`` is called,
            not deferred until the result is iterated.
        """
        # Check eagerly: inside a generator function this check would only
        # run once the caller started iterating.
        if not self._static:
            raise RuntimeError('non-static datasets cannot fetch list of ids')
        return (row_id for (_idx, (row_id, _datum)) in self)
class NearestNeighbor(BaseService):
    """Nearest Neighbor service."""

    @classmethod
    def name(cls):
        """Returns the service name, ``'nearest_neighbor'``."""
        return 'nearest_neighbor'

    @classmethod
    def _client_class(cls):
        # Client class from the jubatus package for this service.
        return jubatus.nearest_neighbor.client.NearestNeighbor

    @classmethod
    def _embedded_class(cls):
        # Embedded-mode class from ``jubatus.embedded`` for this service.
        return jubatus.embedded.NearestNeighbor

    def set_row(self, dataset):
        """Stores each record of ``dataset`` as a row keyed by its id.

        If a row with the same id already exists, it is overwritten (note
        that this behavior is different from that of recommender); otherwise
        a new row entry is created.  If the server that manages the row and
        the server that received this RPC request are the same, the update
        is reflected instantly; otherwise it is reflected after mix.

        Parameters
        ----------
        dataset : Dataset
            Every record must carry an id.

        Yields
        ------
        (idx, row_id, result)
            Record index, row id and the value returned by ``set_row``.

        Raises
        ------
        RuntimeError
            If a record has no id (raised while iterating).
        """
        cli = self._client()
        for (idx, (row_id, d)) in dataset:
            if row_id is None:
                raise RuntimeError('dataset must have id.')
            result = cli.set_row(row_id, d)
            yield (idx, row_id, result)

    def neighbor_row_from_id(self, dataset, size=10):
        """For each record, returns up to ``size`` rows nearest to the row
        with the record's id, with their distance values.

        Yields ``(idx, row_id, result)`` per record.

        Raises
        ------
        RuntimeError
            If a record has no id (raised while iterating).
        """
        cli = self._client()
        for (idx, (row_id, _)) in dataset:
            if row_id is None:
                raise RuntimeError('each data point must have its id.')
            result = cli.neighbor_row_from_id(row_id, size)
            yield (idx, row_id, result)

    def neighbor_row_from_datum(self, dataset, size=10):
        """For each record, returns up to ``size`` rows whose datum is
        nearest to the record's datum, with their distance values.

        Only the datum is sent to the server, so records without an id are
        accepted (``row_id`` is then yielded as-is), matching
        ``similar_row_from_datum``.

        Yields ``(idx, row_id, result)`` per record.
        """
        cli = self._client()
        for (idx, (row_id, d)) in dataset:
            result = cli.neighbor_row_from_datum(d, size)
            yield (idx, row_id, result)

    def similar_row_from_id(self, dataset, size=10):
        """For each record, returns up to ``size`` rows most similar to the
        row with the record's id, with their similarity values.

        Yields ``(idx, row_id, result)`` per record.

        Raises
        ------
        RuntimeError
            If a record has no id (raised while iterating).
        """
        cli = self._client()
        for (idx, (row_id, _)) in dataset:
            if row_id is None:
                raise RuntimeError(
                    'Non ID-based datasets must use `similar_row_from_datum`')
            result = cli.similar_row_from_id(row_id, size)
            yield (idx, row_id, result)

    def similar_row_from_datum(self, dataset, size=10):
        """For each record, returns up to ``size`` rows whose datum is most
        similar to the record's datum, with their similarity values.

        Records without an id are accepted (``row_id`` is yielded as-is).

        Yields ``(idx, row_id, result)`` per record.
        """
        cli = self._client()
        for (idx, (row_id, d)) in dataset:
            result = cli.similar_row_from_datum(d, size)
            yield (idx, row_id, result)

    def get_all_rows(self):
        """Returns the list of all row IDs."""
        cli = self._client()
        return cli.get_all_rows()
class Config(GenericConfig):
    """Configuration to run Nearest Neighbor service."""

    @classmethod
    def methods(cls):
        """Returns the names of the supported methods."""
        return ['lsh', 'minhash', 'euclid_lsh']

    @classmethod
    def _default_method(cls):
        # Presumably used by GenericConfig when no method is given.
        return 'lsh'

    @classmethod
    def _default_parameter(cls, method):
        """Returns the default parameter dict for ``method``.

        All supported methods share the same defaults.  A new dict is built
        on every call, so callers may mutate the result safely.

        Raises
        ------
        RuntimeError
            If ``method`` is not one of ``cls.methods()``.
        """
        # Use ``cls`` so a subclass overriding ``methods()`` is honored.
        if method not in cls.methods():
            raise RuntimeError('unknown method: {0}'.format(method))
        return {
            'threads': 1,
            'hash_num': 128,
        }