Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 1.0 Candidate #291

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,17 @@ dependencies {
compile 'com.typesafe:config:1.3.0'
compile 'org.slf4j:slf4j-api:1.7.7'
compile 'joda-time:joda-time:2.5'
compile 'org.projectlombok:lombok:1.14.8'
compile 'org.projectlombok:lombok:1.16.8'
compile 'org.apache.commons:commons-lang3:3.4'
compile 'org.reflections:reflections:0.9.9'
compile 'it.unimi.dsi:fastutil:7.0.12'

//Validation
compile 'org.hibernate:hibernate-validator:5.2.4.Final'
compile 'javax.el:javax.el-api:2.2.4'
compile 'org.glassfish.web:javax.el:2.2.4'
compile 'org.hibernate:hibernate-validator-cdi:5.2.4.Final'

testCompile 'org.slf4j:slf4j-simple:1.7.7'
testCompile 'junit:junit:4.11'
}
45 changes: 45 additions & 0 deletions core/src/main/java/com/airbnb/aerosolve/core/Example.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.airbnb.aerosolve.core;

import com.airbnb.aerosolve.core.models.AbstractModel;
import com.airbnb.aerosolve.core.features.MultiFamilyVector;
import com.airbnb.aerosolve.core.transforms.Transformer;
import java.util.Iterator;

/**
*
*/
public interface Example extends Iterable<MultiFamilyVector> {
MultiFamilyVector context();

MultiFamilyVector createVector();

MultiFamilyVector addToExample(MultiFamilyVector vector);

Example transform(Transformer transformer,
AbstractModel model);

default Example transform(Transformer transformer) {
return transform(transformer, null);
}

/**
* Returns the only MultiFamilyVector in this Example.
*
* If the Example contains nothing or more than one thing, this will throw an
* IllegalStateException.
*
* (Brad): Lots of code assumes the Example has only one item. Ideally, we should remove that
* assumption and then this method. This method helps us find code paths making that assumption.
*/
default MultiFamilyVector only() {
Iterator<MultiFamilyVector> iterator = iterator();
if (!iterator.hasNext()) {
throw new IllegalStateException("Called only() on an Example which contains nothing");
}
MultiFamilyVector result = iterator.next();
if (iterator.hasNext()) {
throw new IllegalStateException("Called only() on an Example containing more than one vector.");
}
return result;
}
}
123 changes: 123 additions & 0 deletions core/src/main/java/com/airbnb/aerosolve/core/FeatureVector.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package com.airbnb.aerosolve.core;

import com.airbnb.aerosolve.core.features.Feature;
import com.airbnb.aerosolve.core.features.FeatureRegistry;
import com.airbnb.aerosolve.core.features.FeatureValue;
import com.airbnb.aerosolve.core.features.SimpleFeatureValue;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.objects.Object2DoubleMap;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;

/**
* When iterating a FeatureVector it may not always return a new copy of each value. If you want
* to save the values returned by the iterator, use the entry set instead.
*/
public interface FeatureVector extends Object2DoubleMap<Feature>, Iterable<FeatureValue> {
FeatureRegistry registry();

default void putString(Feature feature) {
put(feature, 1.0d);
}

default Iterator<FeatureValue> fastIterator() {
return iterator();
}

/**
* Use this if you intend to store the values. Don't use foreach.
*/
default Set<FeatureValue> featureValueEntrySet() {
return Sets.newHashSet(iterator());
}

@Override
default void forEach(Consumer<? super FeatureValue> action) {
Iterator<FeatureValue> iter = fastIterator();
while (iter.hasNext()) {
action.accept(iter.next());
}
}

default double get(String familyName, String featureName) {
Feature feature = registry().feature(familyName, featureName);
return getDouble(feature);
}

default boolean containsKey(String familyName, String featureName) {
Feature feature = registry().feature(familyName, featureName);
return containsKey(feature);
}

// TODO (Brad): This kind of breaks the abstraction. Do all Features have families?
default void put(String familyName, String featureName, double value) {
Feature feature = registry().feature(familyName, featureName);
put(feature, value);
}

default void putString(String familyName, String featureName) {
Feature feature = registry().feature(familyName, featureName);
putString(feature);
}

default Iterable<FeatureValue> withDropout(double dropout) {
return Iterables.filter(this, f -> ThreadLocalRandom.current().nextDouble() >= dropout);
}

default Iterable<FeatureValue> iterateMatching(List<Feature> features) {
Preconditions.checkNotNull(features, "Cannot iterate all features when features is null");
return () -> new FeatureSetIterator(this, features);
}

default double[] denseArray() {
double[] result = new double[size()];
int i = 0;
for (FeatureValue value : this) {
result[i] = value.value();
i++;
}
return result;
}

class FeatureSetIterator implements Iterator<FeatureValue> {

private final FeatureVector vector;
private final List<Feature> features;
private SimpleFeatureValue entry = SimpleFeatureValue.of(null, 0.0d);
private int index = -1;
private int nextIndex = 0;

public FeatureSetIterator(FeatureVector vector, List<Feature> features) {
this.vector = vector;
this.features = features;
}

@Override
public boolean hasNext() {
if (index >= nextIndex) {
nextIndex++;
while(nextIndex < features.size() && !vector.containsKey(features.get(nextIndex))) {
nextIndex++;
}
}
return nextIndex < features.size();
}

@Override
public FeatureValue next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
entry.feature(features.get(nextIndex));
entry.value(vector.getDouble(entry.feature()));
index = nextIndex;
return entry;
}
}
}
179 changes: 179 additions & 0 deletions core/src/main/java/com/airbnb/aerosolve/core/features/BasicFamily.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package com.airbnb.aerosolve.core.features;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.Reference2ObjectMap;
import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Getter;
import lombok.Synchronized;
import lombok.Value;
import lombok.experimental.Accessors;

/**
*
*/
@Accessors(fluent = true, chain = true)
public class BasicFamily implements Family, Serializable {

private final int hashCode;
private Map<String, Feature> featuresByName;
@Getter
private final String name;
@Getter
private final int index;
private final AtomicInteger featureCount;
private Feature[] featuresByIndex;
@Getter
private boolean isDense = false;
private Map<Feature, Map<Feature, FeatureJoin[]>> crosses;

BasicFamily(String name, int index) {
Preconditions.checkNotNull(name, "All Families must have a name");
this.name = name;
this.index = index;
this.featureCount = new AtomicInteger(0);
this.crosses = new Object2ObjectOpenHashMap<>();
this.hashCode = name.hashCode();
}

@Override
public void markDense() {
if (featuresByName != null && !featuresByName.isEmpty()) {
throw new IllegalStateException("Tried to make a family dense but it already has features"
+ " defined by name. Some code probably thinks it's sparse.");
}
isDense = true;
}

@Override
public Feature feature(String featureName) {
if (isDense) {
// Note that it's not recommended to use this method if the type is DENSE.
// Just call feature(int). It will be faster.
Integer index = Ints.tryParse(featureName);
if (index == null) {
throw new IllegalArgumentException(String.format(
"Could not parse %s to a valid integer for lookup in a dense family: %s. Dense families "
+ "do not have names for each feature.", featureName, name()));
}
return feature(index);
}
if (featuresByName == null) {
featuresByName = new ConcurrentHashMap<>(allocationSize());
}
Feature feature = featuresByName.computeIfAbsent(
featureName,
innerName -> new Feature(this, innerName, featureCount.getAndIncrement())
);
if (featuresByIndex == null || feature.index() >= featuresByIndex.length) {
resizeFeaturesByIndex(feature.index());
}
if (featuresByIndex[feature.index()] == null) {
featuresByIndex[feature.index()] = feature;
}
return feature;
}

@Override
public Feature feature(int index) {
if (featuresByIndex == null || index >= featuresByIndex.length) {
if (isDense) {
resizeFeaturesByIndex(index);
} else {
return null;
}
}
if (isDense && featuresByIndex[index] == null) {
featuresByIndex[index] = new Feature(this, String.valueOf(index), index);
}
return featuresByIndex[index];
}

@Synchronized
private void resizeFeaturesByIndex(int index) {
if (featuresByIndex == null) {
featuresByIndex = new Feature[Family.allocationSize(index + 1)];
return;
}
// We check outside and inside this method because it's synchronized and this can change
// between when we intend to enter and when we actually enter.
if (index < featuresByIndex.length) {
return;
}
// Need to resize.
int length = featuresByIndex.length;
while (index >= length) {
length = length * 2;
}
featuresByIndex = Arrays.copyOf(featuresByIndex, length);
}

@Override
public Feature cross(Feature left, Feature right, String separator) {
Map<Feature, FeatureJoin[]> rightMap = crosses.get(left);
if (rightMap == null) {
rightMap = new Object2ObjectOpenHashMap<>();
crosses.put(left, rightMap);
}
FeatureJoin[] joinArr = rightMap.get(right);
if (joinArr == null) {
joinArr = new FeatureJoin[2];
rightMap.put(right, joinArr);
}
int i;
for (i = 0; i < joinArr.length; i++) {
FeatureJoin join = joinArr[i];
if (join == null) {
break;
}
if (join.separator().equals(separator)) {
return join.feature();
}
}
Feature feature = feature(left.name() + separator + right.name());
FeatureJoin newJoin = new FeatureJoin(feature, separator);
if (i >= joinArr.length) {
joinArr = Arrays.copyOf(joinArr, joinArr.length * 2);
rightMap.put(right, joinArr);
}
joinArr[i] = newJoin;
return feature;
}

public int size() {
return featureCount.get();
}

@Override
public int hashCode() {
return hashCode;
}

@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (!(obj instanceof Family)) {
return false;
}
return name.equals(((Family) obj).name());
}

@Override
public String toString() {
return name;
}

@Value
private static class FeatureJoin {
private final Feature feature;
private final String separator;
}
}
Loading