-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathLDA.hs
More file actions
154 lines (137 loc) · 6.27 KB
/
LDA.hs
File metadata and controls
154 lines (137 loc) · 6.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
{-# LANGUAGE TypeApplications #-}
{- | A [Latent Dirichlet Allocation (LDA)](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) model
(or topic model) for learning the distribution over words and topics in a text document.
-}
module LDA where
import Control.Monad (replicateM)
import Data.Kind (Constraint)
import Env (Assign ((:=)), Env, Observable, Observables,
get, nil, (<:>))
import Inference.MH as MH (mhRaw)
import Inference.SIM as SIM (simulate)
import Model (Model, categorical, dirichlet, discrete')
import Sampler (Sampler)
{- | An LDA environment.

Each variable is bound to a list of values (see the shape example below).
Assuming 1 document with 2 topics and a vocabulary of 4 words, the parameters
of the model environment would have the following shape:

> θ = [ [prob_topic_1, prob_topic_2]                          -- topic probabilities of document 1
>     ]
> φ = [ [prob_word_1, prob_word_2, prob_word_3, prob_word_4]  -- word probabilities of topic 1
>     , [prob_word_1, prob_word_2, prob_word_3, prob_word_4]  -- word probabilities of topic 2
>     ]
-}
type TopicEnv =
  '[ "θ" ':= [Double] -- probabilities of each topic in a document
   , "φ" ':= [Double] -- probabilities of each word in a topic
   , "w" ':= String   -- a word
   ]
-- | Prior distribution over the topics of a single document.
--
-- The distribution is a Dirichlet parameterised by a vector of ones, one per
-- topic. The draw is tagged with the observable variable @θ@.
docTopicPrior :: Observable env "θ" [Double]
  -- | number of topics
  => Int
  -- | probability of each topic
  -> Model env sig m [Double]
docTopicPrior n_topics = dirichlet onesPerTopic #θ
  where
    -- a parameter of 1 for every topic
    onesPerTopic = replicate n_topics 1
-- | Prior distribution over the words of a single topic.
--
-- The distribution is a Dirichlet parameterised by a vector of ones, one
-- entry per vocabulary word. The draw is tagged with the observable
-- variable @φ@.
topicWordPrior :: Observable env "φ" [Double]
  -- | vocabulary
  => [String]
  -- | probability of each word
  -> Model env sig m [Double]
topicWordPrior vocab = dirichlet (map (const 1) vocab) #φ
-- | Draw one word from a categorical distribution over the vocabulary.
--
-- The distribution is built from (word, probability) pairs: the i-th word is
-- paired with the i-th probability. The draw is tagged with the observable
-- variable @w@.
wordDist :: Observable env "w" String
  -- | vocabulary
  => [String]
  -- | probability of each word
  -> [Double]
  -- | generated word
  -> Model env sig m String
wordDist vocab ps = categorical weightedWords #w
  where
    -- pair each word with its probability ('zip' stops at the shorter list)
    weightedWords = zip vocab ps
-- | Generative model for a single document.
--
-- First draws a word distribution for each of the @n_topics@ topics and a
-- topic distribution for the document. Then, for each of the @n_words@ word
-- positions, it draws a topic index and a word from that topic's word
-- distribution.
topicModel :: (Observables env '["φ", "θ"] [Double],
               Observable env "w" String)
  -- | vocabulary
  => [String]
  -- | number of topics
  -> Int
  -- | number of words
  -> Int
  -- | generated words
  -> Model env sig m [String]
topicModel vocab n_topics n_words = do
  -- one word distribution per topic
  wordPsPerTopic <- replicateM n_topics (topicWordPrior vocab)
  -- the topic distribution of this document
  topicPs <- docTopicPrior n_topics
  replicateM n_words $ do
    topicIx <- discrete' topicPs
    -- NOTE(review): '!!' is partial; this assumes discrete' yields an index
    -- smaller than n_topics (i.e. θ has exactly n_topics entries).
    wordDist vocab (wordPsPerTopic !! topicIx)
-- | Generative model for a collection of documents.
--
-- Runs 'topicModel' once per entry of the word-count list, with a shared
-- vocabulary and number of topics. Returns the generated words of each
-- document, in order.
--
-- NOTE(review): 'topicModel' draws its own topic–word distributions (φ), so
-- each document gets independent φ values here. Standard LDA shares φ across
-- documents; confirm this is intended.
topicModels :: (Observables env '["φ", "θ"] [Double],
                Observable env "w" String)
  -- | vocabulary
  => [String]
  -- | number of topics
  -> Int
  -- | number of words for each document
  -> [Int]
  -- | generated words for each document
  -> Model env sig m [[String]]
topicModels vocab n_topics doc_words =
  mapM (topicModel vocab n_topics) doc_words
-- | Example four-word vocabulary used by the simulation and inference
-- examples in this module.
vocab :: [String]
vocab =
  [ "DNA"
  , "evolution"
  , "parsing"
  , "phonology"
  ]
-- | Simulate a 100-word document from 'topicModel' with 2 topics.
--
-- The topic (θ) and topic–word (φ) probabilities are fixed in the input
-- environment and no words are supplied. Returns the generated words; the
-- output environment is discarded.
simLDA :: Sampler [String]
simLDA = fst <$> SIM.simulate env_in (topicModel vocab n_topics n_words)
  where
    -- Model inputs
    n_words  = 100
    n_topics = 2
    -- Model environment: fixed θ and φ, and an empty list of words
    env_in :: Env TopicEnv
    env_in =  #θ := [[0.5, 0.5]]
          <:> #φ := [ [ 0.12491280814569208, 1.9941599739151505e-2
                      , 0.5385152817942926, 0.3166303103208638 ]
                    , [ 1.72605174564027e-2, 2.9475900240868515e-2
                      , 9.906011619752661e-2, 0.8542034661052021 ] ]
          <:> #w := []
          <:> nil
-- | Example document of 100 words.
--
-- The document is the same 20-word passage repeated five times; the passage
-- is mostly the words DNA and evolution.
topic_data :: [String]
topic_data = concat (replicate 5 passage)
  where
    passage =
      concat (replicate 5 ["DNA", "evolution"])
        ++ [ "parsing", "phonology", "DNA", "evolution", "DNA", "parsing"
           , "evolution", "phonology", "evolution", "DNA" ]
-- | Metropolis–Hastings inference over 'topicModel' for 'topic_data'.
--
-- Runs 500 MH iterations with θ and φ left empty and the words observed.
-- Returns the θ and φ values of the first environment in the returned trace.
mhLDA :: Sampler ([[Double]], [[Double]])
mhLDA = do
  -- Model inputs: the model must generate exactly as many words as are
  -- observed, so the word count is derived from the data itself.
  let n_words  = length topic_data
      n_topics = 2
      -- Model environment: θ and φ empty, words observed
      env_mh_in :: Env TopicEnv
      env_mh_in = #θ := [] <:> #φ := [] <:> #w := topic_data <:> nil
  -- Run MH for 500 iterations
  env_mh_outs <-
    MH.mhRaw 500 (topicModel vocab n_topics n_words) env_mh_in nil (#φ <:> #θ <:> nil)
  -- Take the most recent sampled parameters from the MH trace.
  -- NOTE(review): assumes mhRaw returns the trace newest-first.
  case env_mh_outs of
    env_pred : _ -> return (get #θ env_pred, get #φ env_pred)
    []           -> error "mhLDA: MH.mhRaw returned an empty trace"