Skip to content

Commit 9a790e6

Browse files
authored
Merge pull request beehive-lab#65 from orionpapadakis/refactor/weight_abstraction
Weight Abstractions
2 parents 2376059 + 83d4f44 commit 9a790e6

80 files changed

Lines changed: 1010 additions & 1471 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

src/main/java/org/beehive/gpullama3/core/types/Pair.java renamed to src/main/java/org/beehive/gpullama3/auxiliary/Pair.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
package org.beehive.gpullama3.core.types;
1+
package org.beehive.gpullama3.auxiliary;
22

33
public record Pair<First, Second>(First first, Second second) {
44
}

src/main/java/org/beehive/gpullama3/core/model/tensor/Q8_0QuantizedTensor.java

Lines changed: 0 additions & 177 deletions
This file was deleted.

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
package org.beehive.gpullama3.inference;
22

33
import org.beehive.gpullama3.auxiliary.Parallel;
4-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
4+
import org.beehive.gpullama3.tensor.standard.FloatTensor;
55
import org.beehive.gpullama3.inference.state.Phi3State;
66
import org.beehive.gpullama3.inference.state.State;
77
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
@@ -583,7 +583,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
583583
final Configuration configuration = model.configuration();
584584
final TornadoWeights weights = (TornadoWeights) model.weights();
585585

586-
MemorySegment.copy(weights.getTokenEmbeddingTable().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
586+
MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
587587

588588
return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
589589
}

src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
package org.beehive.gpullama3.inference.operation;
22

3-
import org.beehive.gpullama3.core.types.Pair;
3+
import org.beehive.gpullama3.auxiliary.Pair;
44

55
public final class RoPE {
66
public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta,

src/main/java/org/beehive/gpullama3/inference/sampler/CategoricalSampler.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
package org.beehive.gpullama3.inference.sampler;
22

3-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
3+
import org.beehive.gpullama3.tensor.standard.FloatTensor;
44
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
55

66
import java.util.random.RandomGenerator;

src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
package org.beehive.gpullama3.inference.sampler;
22

33
import org.beehive.gpullama3.Options;
4-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
4+
import org.beehive.gpullama3.tensor.standard.FloatTensor;
55
import org.beehive.gpullama3.model.Model;
66
import org.beehive.gpullama3.tornadovm.utils.FloatArrayUtils;
77
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

src/main/java/org/beehive/gpullama3/inference/sampler/ToppSampler.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
package org.beehive.gpullama3.inference.sampler;
22

3-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
3+
import org.beehive.gpullama3.tensor.standard.FloatTensor;
44
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
55

66
import java.util.Comparator;

src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
package org.beehive.gpullama3.inference.state;
22

3-
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
4-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
3+
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
4+
import org.beehive.gpullama3.tensor.standard.FloatTensor;
55
import org.beehive.gpullama3.model.Configuration;
66
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
77
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
package org.beehive.gpullama3.inference.state;
22

3-
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
4-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
3+
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
4+
import org.beehive.gpullama3.tensor.standard.FloatTensor;
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
77
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
package org.beehive.gpullama3.inference.state;
22

3-
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
4-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
3+
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
4+
import org.beehive.gpullama3.tensor.standard.FloatTensor;
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
77
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

0 commit comments

Comments
 (0)