|
| 1 | +""" |
| 2 | +Data Structure: PRISM Rule-Based Classifier |
| 3 | +
|
| 4 | +Description: |
| 5 | + PRISM is a rule-based machine learning algorithm used for classification. |
| 6 | + It generates simple if-then rules for each class by analyzing attribute-value pairs. |
| 7 | +
|
| 8 | +Use Case: |
| 9 | + Useful in applications requiring interpretable classification rules, |
| 10 | + such as medical diagnosis, decision support systems, or expert systems. |
| 11 | +
|
| 12 | +Time Complexity: |
| 13 | + O(N * M * V), where |
| 14 | + - N = number of instances |
| 15 | + - M = number of attributes |
| 16 | + - V = average number of attribute values |
| 17 | +
|
| 18 | +Space Complexity: |
| 19 | + O(R), where R = number of generated rules. |
| 20 | +""" |
| 21 | + |
| 22 | +from collections import Counter |
| 23 | + |
| 24 | +class PRISM: |
| 25 | + def __init__(self, data, attributes, class_label): |
| 26 | + """ |
| 27 | + Initialize the PRISM classifier. |
| 28 | + |
| 29 | + Args: |
| 30 | + data (list of dict): Dataset as list of instances (dict of attribute: value). |
| 31 | + attributes (list): List of attribute names. |
| 32 | + class_label (str): Name of the class attribute. |
| 33 | + """ |
| 34 | + self.data = data |
| 35 | + self.attributes = attributes |
| 36 | + self.class_label = class_label |
| 37 | + self.rules = [] |
| 38 | + |
| 39 | + def train(self): |
| 40 | + """Generate rules for each class.""" |
| 41 | + classes = set([d[self.class_label] for d in self.data]) |
| 42 | + for cls in classes: |
| 43 | + examples = [d for d in self.data if d[self.class_label] == cls] |
| 44 | + while examples: |
| 45 | + rule = self._generate_rule(examples, cls) |
| 46 | + self.rules.append(rule) |
| 47 | + # Remove examples covered by this rule |
| 48 | + examples = [ex for ex in examples if not self._matches_rule(ex, rule)] |
| 49 | + |
| 50 | + def _generate_rule(self, examples, cls): |
| 51 | + """Generate a single rule for a class.""" |
| 52 | + rule = {} |
| 53 | + remaining_attrs = self.attributes.copy() |
| 54 | + while True: |
| 55 | + best_attr, best_value, best_accuracy = None, None, 0 |
| 56 | + for attr in remaining_attrs: |
| 57 | + values = set([ex[attr] for ex in examples]) |
| 58 | + for val in values: |
| 59 | + covered = [ex for ex in examples if ex[attr] == val] |
| 60 | + accuracy = sum(1 for ex in covered if ex[self.class_label] == cls) / len(covered) |
| 61 | + if accuracy > best_accuracy: |
| 62 | + best_attr, best_value, best_accuracy = attr, val, accuracy |
| 63 | + if best_attr is None: |
| 64 | + break |
| 65 | + rule[best_attr] = best_value |
| 66 | + remaining_attrs.remove(best_attr) |
| 67 | + # Filter examples covered by this condition |
| 68 | + examples = [ex for ex in examples if ex[best_attr] == best_value] |
| 69 | + if all(ex[self.class_label] == cls for ex in examples): |
| 70 | + break |
| 71 | + rule['class'] = cls |
| 72 | + return rule |
| 73 | + |
| 74 | + def _matches_rule(self, instance, rule): |
| 75 | + """Check if an instance matches a rule.""" |
| 76 | + for attr, val in rule.items(): |
| 77 | + if attr == 'class': |
| 78 | + continue |
| 79 | + if instance[attr] != val: |
| 80 | + return False |
| 81 | + return True |
| 82 | + |
| 83 | + def predict(self, instance): |
| 84 | + """Predict class label for a single instance.""" |
| 85 | + for rule in self.rules: |
| 86 | + if self._matches_rule(instance, rule): |
| 87 | + return rule['class'] |
| 88 | + return None |
| 89 | + |
| 90 | + |
| 91 | +def main(): |
| 92 | + """Test PRISM implementation.""" |
| 93 | + data = [ |
| 94 | + {'Outlook':'Sunny', 'Temp':'Hot', 'Humidity':'High', 'Wind':'Weak', 'PlayTennis':'No'}, |
| 95 | + {'Outlook':'Sunny', 'Temp':'Hot', 'Humidity':'High', 'Wind':'Strong', 'PlayTennis':'No'}, |
| 96 | + {'Outlook':'Overcast', 'Temp':'Hot', 'Humidity':'High', 'Wind':'Weak', 'PlayTennis':'Yes'}, |
| 97 | + {'Outlook':'Rain', 'Temp':'Mild', 'Humidity':'High', 'Wind':'Weak', 'PlayTennis':'Yes'} |
| 98 | + ] |
| 99 | + attributes = ['Outlook', 'Temp', 'Humidity', 'Wind'] |
| 100 | + prism = PRISM(data, attributes, 'PlayTennis') |
| 101 | + prism.train() |
| 102 | + print("Generated Rules:", prism.rules) |
| 103 | + test_instance = {'Outlook':'Rain', 'Temp':'Mild', 'Humidity':'High', 'Wind':'Weak'} |
| 104 | + print("Prediction:", prism.predict(test_instance)) |
| 105 | + |
| 106 | + |
| 107 | +if __name__ == "__main__": |
| 108 | + main() |
0 commit comments